Handle swizzle1d in isResharding #6028

Open: Priya2698 wants to merge 1 commit into `main` from `pm/swizzle1d_resharding`

Conversation

@Priya2698
Collaborator

No description provided.

@Priya2698
Collaborator Author

!test

@greptile-apps
Contributor

greptile-apps bot commented Mar 5, 2026

Greptile Summary

This PR extends isResharding / haveDifferentShardings to correctly handle Swizzle1D transforms, which previously caused an NVF_THROW("Unexpected transform") crash. The fix threads a new pt_to_index map (device-parallel-type → symbolic index Val*) through computeLoopIndex and applies the correct inverse swizzle formula out_idx = (in_idx - pt_val + extent) % extent (the mathematical inverse of the evaluator's forward formula in_idx = (out_idx + pt_val) % extent).

Key changes:

  • computeLoopIndex: gains a pt_to_index parameter; the else-if chain is refactored to if/continue blocks with the new Swizzle1D branch appended.
  • haveDifferentShardings: creates pt_to_index by iterating over kParallelTypeDIDs present in the producer's mesh; this map is also forwarded to the consumer's computeLoopIndex call. If the consumer has a Swizzle1D using a DID type absent from the producer's mesh (possible when meshes differ but only Stream is in parallel_types, bypassing the early-exit guard), pt_to_index.at(...) will throw std::out_of_range.
  • Tests: two new tests — Swizzle1D_DIDToStream (producer DIDx-sharded, consumer stream-via-swizzle → resharding) and Swizzle1D_ConsistentSwizzle (same swizzle on both sides → non-resharding). The consistent-swizzle test is missing a DIDx-check assertion.

Confidence Score: 3/5

  • Mostly safe for the common same-mesh case; a potential out_of_range crash exists for cross-mesh Swizzle1D when only Stream is checked.
  • The core inverse-swizzle formula is mathematically correct and the two test cases validate the primary use cases. However, pt_to_index is built solely from the producer's mesh, which can cause an unguarded std::out_of_range exception when a consumer's Swizzle1D references a DID type not present in the producer's mesh — a scenario that the existing DID-type early-exit guard does not cover when only ParallelType::Stream is in parallel_types. Additionally, the consistent-swizzle test omits a DIDx assertion that mirrors the pattern in the DID-to-Stream test.
  • csrc/multidevice/resharding.cpp — the pt_to_index construction (lines 266–278) should include the consumer's mesh to avoid potential out_of_range when cross-mesh Swizzle1D scenarios are encountered.

Flowchart

```mermaid
%%{init: {'theme': 'neutral'}}%%
flowchart TD
    A["haveDifferentShardings(producer, consumer, parallel_types)"] --> B{CPU scalar?}
    B -- yes --> C[return false]
    B -- no --> D{Different meshes\n& DID type present?}
    D -- yes --> E[return true]
    D -- no --> F["Build pt_to_index\nfrom producer's mesh only\n(for Swizzle1D)"]
    F --> G[Create symbolic indices\nfor producer logical IDs\n& consumer root IDs]
    G --> H["computeLoopIndex(p_id, kLogical, pt_to_index)"]
    H --> H1{Transform type?}
    H1 -- Split --> H2[div/mod]
    H1 -- Merge --> H3[add/mul]
    H1 -- Swizzle1D --> H4["out_idx = (in_idx - pt_val + extent) % extent\npt_val = pt_to_index.at(swizzle->parallelType())"]
    H1 -- other --> H5[NVF_THROW]
    H4 --> I["computeLoopIndex(c_id, kRoot, pt_to_index)"]
    I --> J{p_index == c_index\nunder assumptions?}
    J -- yes --> K[return false: non-resharding]
    J -- no --> L[return true: resharding]

    style H4 fill:#fff3cd,stroke:#ffc107
    style F fill:#fff3cd,stroke:#ffc107
```

Last reviewed commit: 57785eb

Comment on lines +266 to +278

```cpp
std::unordered_map<ParallelType, Val*> pt_to_index;
const DeviceMesh& mesh = producer->getDeviceMesh();
for (ParallelType pt : kParallelTypeDIDs) {
  if (!mesh.hasParallelType(pt)) {
    continue;
  }
  Val* device_idx = IrBuilder::create<Val>(DataType::Index);
  pt_to_index[pt] = device_idx;
  Val* team_size = IrBuilder::create<Val>(mesh.size(pt), DataType::Index);
  assumptions.push_back(
      SimplifyingIrBuilder::leExpr(fusion->zeroVal(), device_idx));
  assumptions.push_back(SimplifyingIrBuilder::ltExpr(device_idx, team_size));
}
```

pt_to_index is populated exclusively from the producer's mesh. It is then passed verbatim into computeLoopIndex for the consumer side as well (lines 358–362). This means that if haveDifferentShardings is called with parallel_types = {Stream} and the consumer holds a Swizzle1D referencing a ParallelType (e.g. DIDy) that is not present in the producer's mesh, pt_to_index.at(swizzle->parallelType()) (line 110) will throw std::out_of_range.

The early-exit guard at lines 152–158 only fires when parallel_types contains a DID type, so it does not protect this path when only ParallelType::Stream is being checked.

A minimal fix is to also fold the consumer's mesh into pt_to_index:

Suggested change

```diff
 std::unordered_map<ParallelType, Val*> pt_to_index;
-const DeviceMesh& mesh = producer->getDeviceMesh();
-for (ParallelType pt : kParallelTypeDIDs) {
-  if (!mesh.hasParallelType(pt)) {
-    continue;
-  }
-  Val* device_idx = IrBuilder::create<Val>(DataType::Index);
-  pt_to_index[pt] = device_idx;
-  Val* team_size = IrBuilder::create<Val>(mesh.size(pt), DataType::Index);
-  assumptions.push_back(
-      SimplifyingIrBuilder::leExpr(fusion->zeroVal(), device_idx));
-  assumptions.push_back(SimplifyingIrBuilder::ltExpr(device_idx, team_size));
-}
+// Collect device-parallel symbolic indices from both meshes.
+for (const DeviceMesh* m :
+     {&producer->getDeviceMesh(), &consumer->getDeviceMesh()}) {
+  for (ParallelType pt : kParallelTypeDIDs) {
+    if (!m->hasParallelType(pt) || pt_to_index.count(pt)) {
+      continue;
+    }
+    Val* device_idx = IrBuilder::create<Val>(DataType::Index);
+    pt_to_index[pt] = device_idx;
+    Val* team_size = IrBuilder::create<Val>(m->size(pt), DataType::Index);
+    assumptions.push_back(
+        SimplifyingIrBuilder::leExpr(fusion->zeroVal(), device_idx));
+    assumptions.push_back(
+        SimplifyingIrBuilder::ltExpr(device_idx, team_size));
+  }
+}
```

Comment on lines +676 to +678

```cpp
  EXPECT_FALSE(haveDifferentShardings(
      in, DomainType::kLoop, out, DomainType::kLoop, {ParallelType::Stream}));
}
```

Swizzle1D_ConsistentSwizzle only asserts that {ParallelType::Stream} is non-resharding. Because neither tensor's loop domain has a DIDx-parallelized axis in this configuration, it would be useful to also verify that checking {ParallelType::DIDx} does not regress (i.e. still returns false). This mirrors the two-assertion pattern used in Swizzle1D_DIDToStream and guards against future regressions where a swizzle-internal DID type is accidentally surfaced.

Suggested change

```diff
 EXPECT_FALSE(haveDifferentShardings(
     in, DomainType::kLoop, out, DomainType::kLoop, {ParallelType::Stream}));
+EXPECT_FALSE(haveDifferentShardings(
+    in, DomainType::kLoop, out, DomainType::kLoop, {ParallelType::DIDx}));
 }
```