!test
Greptile Summary

This PR extends `haveDifferentShardings` to handle `Swizzle1D` transforms.
Confidence Score: 3/5
Flowchart

```mermaid
%%{init: {'theme': 'neutral'}}%%
flowchart TD
    A["haveDifferentShardings(producer, consumer, parallel_types)"] --> B{CPU scalar?}
    B -- yes --> C[return false]
    B -- no --> D{"Different meshes\n& DID type present?"}
    D -- yes --> E[return true]
    D -- no --> F["Build pt_to_index\nfrom producer's mesh only\n(for Swizzle1D)"]
    F --> G["Create symbolic indices\nfor producer logical IDs\n& consumer root IDs"]
    G --> H["computeLoopIndex(p_id, kLogical, pt_to_index)"]
    H --> H1{Transform type?}
    H1 -- Split --> H2[div/mod]
    H1 -- Merge --> H3[add/mul]
    H1 -- Swizzle1D --> H4["out_idx = (in_idx - pt_val + extent) % extent\npt_val = pt_to_index.at(swizzle->parallelType())"]
    H1 -- other --> H5[NVF_THROW]
    H4 --> I["computeLoopIndex(c_id, kRoot, pt_to_index)"]
    I --> J{"p_index == c_index\nunder assumptions?"}
    J -- yes --> K[return false: non-resharding]
    J -- no --> L[return true: resharding]
    style H4 fill:#fff3cd,stroke:#ffc107
    style F fill:#fff3cd,stroke:#ffc107
```
Last reviewed commit: 57785eb
```cpp
std::unordered_map<ParallelType, Val*> pt_to_index;
const DeviceMesh& mesh = producer->getDeviceMesh();
for (ParallelType pt : kParallelTypeDIDs) {
  if (!mesh.hasParallelType(pt)) {
    continue;
  }
  Val* device_idx = IrBuilder::create<Val>(DataType::Index);
  pt_to_index[pt] = device_idx;
  Val* team_size = IrBuilder::create<Val>(mesh.size(pt), DataType::Index);
  assumptions.push_back(
      SimplifyingIrBuilder::leExpr(fusion->zeroVal(), device_idx));
  assumptions.push_back(SimplifyingIrBuilder::ltExpr(device_idx, team_size));
}
```
pt_to_index is populated exclusively from the producer's mesh. It is then passed verbatim into computeLoopIndex for the consumer side as well (lines 358–362). This means that if haveDifferentShardings is called with parallel_types = {Stream} and the consumer holds a Swizzle1D referencing a ParallelType (e.g. DIDy) that is not present in the producer's mesh, pt_to_index.at(swizzle->parallelType()) (line 110) will throw std::out_of_range.
The early-exit guard at lines 152–158 only fires when parallel_types contains a DID type, so it does not protect this path when only ParallelType::Stream is being checked.
A minimal fix is to also fold the consumer's mesh into pt_to_index:
```suggestion
std::unordered_map<ParallelType, Val*> pt_to_index;
// Collect device-parallel symbolic indices from both meshes.
for (const DeviceMesh* m :
     {&producer->getDeviceMesh(), &consumer->getDeviceMesh()}) {
  for (ParallelType pt : kParallelTypeDIDs) {
    if (!m->hasParallelType(pt) || pt_to_index.count(pt)) {
      continue;
    }
    Val* device_idx = IrBuilder::create<Val>(DataType::Index);
    pt_to_index[pt] = device_idx;
    Val* team_size = IrBuilder::create<Val>(m->size(pt), DataType::Index);
    assumptions.push_back(
        SimplifyingIrBuilder::leExpr(fusion->zeroVal(), device_idx));
    assumptions.push_back(
        SimplifyingIrBuilder::ltExpr(device_idx, team_size));
  }
}
```
```cpp
  EXPECT_FALSE(haveDifferentShardings(
      in, DomainType::kLoop, out, DomainType::kLoop, {ParallelType::Stream}));
}
```
Swizzle1D_ConsistentSwizzle only asserts that {ParallelType::Stream} is non-resharding. Because neither tensor's loop domain has a DIDx-parallelized axis in this configuration, it would be useful to also verify that checking {ParallelType::DIDx} does not regress (i.e. still returns false). This mirrors the two-assertion pattern used in Swizzle1D_DIDToStream and guards against future regressions where a swizzle-internal DID type is accidentally surfaced.
```suggestion
  EXPECT_FALSE(haveDifferentShardings(
      in, DomainType::kLoop, out, DomainType::kLoop, {ParallelType::Stream}));
  EXPECT_FALSE(haveDifferentShardings(
      in, DomainType::kLoop, out, DomainType::kLoop, {ParallelType::DIDx}));
}
```
No description provided.