

@shashforge (Contributor) commented Nov 13, 2025

This fixes a crash in the SCF→GPU conversion when building the per-dimension index for GPU-mapped scf.parallel loops.

Change:

  • Map step and lowerBound through cloningMap, then run ensureLaunchIndependent on the mapped values.
  • If either is still unavailable at the launch scope, emit a match failure; otherwise build the affine.apply.
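The map-then-validate ordering can be illustrated outside MLIR. The following is a minimal, hypothetical C++ model, not the pass's code: Value is a plain int with 0 standing in for an empty mlir::Value, Mapping stands in for cloningMap, and the ensureLaunchIndependent stand-in only accepts values defined above the launch (per the patch's own comment, the real helper can also clone constants). buildIndexOperands and IndexOperands are invented names. The model shows that both operands are checked before any op would be built, so an unavailable operand becomes a reported failure instead of an empty operand reaching the builder.

```cpp
#include <cassert>
#include <string>
#include <unordered_map>
#include <unordered_set>

// Hypothetical model: 0 plays the role of an empty (null) mlir::Value.
using Value = int;

// Stand-in for cloningMap: returns the mapped value, or the input if unmapped.
struct Mapping {
  std::unordered_map<Value, Value> table;
  Value lookupOrDefault(Value v) const {
    auto it = table.find(v);
    return it == table.end() ? v : it->second;
  }
};

// Stand-in for ensureLaunchIndependent: a value defined above the launch is
// usable as-is; anything else yields an empty Value (0).
Value ensureLaunchIndependent(Value v,
                              const std::unordered_set<Value> &aboveLaunch) {
  return aboveLaunch.count(v) ? v : 0;
}

// Operands for the (hypothetical) affine.apply, or an error message.
struct IndexOperands {
  Value step = 0;
  Value lowerBound = 0;
  std::string error;  // Non-empty means "match failure".
};

IndexOperands buildIndexOperands(Value step, Value lowerBound,
                                 const Mapping &cloningMap,
                                 const std::unordered_set<Value> &aboveLaunch) {
  IndexOperands out;
  // 1. Map through cloningMap first.
  Value mappedStep = cloningMap.lookupOrDefault(step);
  Value mappedLowerBound = cloningMap.lookupOrDefault(lowerBound);
  // 2. Then try to make the mapped values launch-independent.
  mappedStep = ensureLaunchIndependent(mappedStep, aboveLaunch);
  mappedLowerBound = ensureLaunchIndependent(mappedLowerBound, aboveLaunch);
  // 3. Report a failure instead of building an op with an empty operand.
  if (!mappedStep || !mappedLowerBound) {
    out.error = "lower bound / step must be constant or defined above "
                "the gpu.launch";
    return out;
  }
  out.step = mappedStep;
  out.lowerBound = mappedLowerBound;
  return out;
}
```

In this model, a step that maps to a value defined above the launch succeeds, while a step with no such definition produces the failure message; in the real pass that message is reported through rewriter.notifyMatchFailure.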

Why this is correct:

  • This matches how the pass already handles launch bounds. It avoids creating an op with empty (null) operands and replaces a segfault with a clear diagnostic.

Tests:

  • Added two small regression tests that lower to gpu.launch and exercise the affine.apply path.

Fixes: #167654

… crash

When lowering mapped scf.parallel to gpu.launch, the per-dimension index
was built with AffineApplyOp::create using ensureLaunchIndependent(lb/step)
directly. If lb/step were not available at the launch scope, that helper
returned an empty value and the builder crashed while creating the op.

Mirror the bound-handling path: first map lb/step through cloningMap, then
call ensureLaunchIndependent. If either operand still isn’t available above
the launch, report a precise match-failure instead of crashing. This makes
conversion fail cleanly on invalid cases and succeed for valid ones.

Add two positive regression tests that previously crashed and now lower
to gpu.launch (and affine.apply).

Fixes: llvm#167654
Signed-off-by: Shashi Shankar <[email protected]>
@llvmbot (Member) commented Nov 13, 2025

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-gpu

Author: Shashi Shankar (shashforge)

Changes

Description
This patch fixes a crash in the SCF -> GPU conversion when building the per‑dimension index for GPU‑mapped scf.parallel loops.

What was wrong

We called AffineApplyOp::create(...) with operands coming from ensureLaunchIndependent(...).

When step or lowerBound weren’t available at the gpu.launch scope, ensureLaunchIndependent returned an empty Value.

Passing that empty value into AffineApplyOp::create caused a crash.

What I changed

Map step and lowerBound through cloningMap.lookupOrDefault(...) first.

Then run ensureLaunchIndependent(...) on the mapped values.

If either is still not available, return rewriter.notifyMatchFailure(...) with a clear message, which yields failure() instead of crashing.

Otherwise, call AffineApplyOp::create with the validated operands.

Fixes: #167654


Full diff: https://github.com/llvm/llvm-project/pull/167959.diff

3 Files Affected:

  • (modified) mlir/lib/Conversion/SCFToGPU/SCFToGPU.cpp (+16-2)
  • (added) mlir/test/Conversion/SCFToGPU/parallel-to-gpu-crash-regression.mlir (+31)
  • (added) mlir/test/Conversion/SCFToGPU/parallel-to-gpu-index-creation.mlir (+24)
diff --git a/mlir/lib/Conversion/SCFToGPU/SCFToGPU.cpp b/mlir/lib/Conversion/SCFToGPU/SCFToGPU.cpp
index 76a822b05a652..309121f520811 100644
--- a/mlir/lib/Conversion/SCFToGPU/SCFToGPU.cpp
+++ b/mlir/lib/Conversion/SCFToGPU/SCFToGPU.cpp
@@ -453,10 +453,24 @@ static LogicalResult processParallelLoop(
           1, 2,
           rewriter.getAffineDimExpr(0) * rewriter.getAffineSymbolExpr(0) +
               rewriter.getAffineSymbolExpr(1));
+      // Map through cloningMap first so we use values valid at the launch
+      // scope, then ensure they are launch-independent (or cloned constants).
+      Value mappedStep = cloningMap.lookupOrDefault(step);
+      Value mappedLowerBound = cloningMap.lookupOrDefault(lowerBound);
+
+      mappedStep = ensureLaunchIndependent(mappedStep);
+      mappedLowerBound = ensureLaunchIndependent(mappedLowerBound);
+
+      // If either cannot be made available above the launch, fail gracefully.
+      if (!mappedStep || !mappedLowerBound) {
+        return rewriter.notifyMatchFailure(
+            parallelOp, "lower bound / step must be constant or defined above "
+                        "the gpu.launch");
+      }
+
       newIndex = AffineApplyOp::create(
           rewriter, loc, annotation.getMap().compose(lowerAndStep),
-          ValueRange{operand, ensureLaunchIndependent(step),
-                     ensureLaunchIndependent(lowerBound)});
+          ValueRange{operand, mappedStep, mappedLowerBound});
       // If there was also a bound, insert that, too.
       // TODO: Check that we do not assign bounds twice.
       if (annotation.getBound()) {
diff --git a/mlir/test/Conversion/SCFToGPU/parallel-to-gpu-crash-regression.mlir b/mlir/test/Conversion/SCFToGPU/parallel-to-gpu-crash-regression.mlir
new file mode 100644
index 0000000000000..01daed159b6fd
--- /dev/null
+++ b/mlir/test/Conversion/SCFToGPU/parallel-to-gpu-crash-regression.mlir
@@ -0,0 +1,31 @@
+// RUN: mlir-opt %s --convert-parallel-loops-to-gpu | FileCheck %s
+
+// Goal: exercise the per-dim index computation
+//        newIndex = hardware_id * step + lowerBound
+// and ensure we see a gpu.launch and an affine.apply (no crash).
+
+module {
+  func.func @two_dim_parallel_mapped() {
+    %c0  = arith.constant 0 : index
+    %c1  = arith.constant 1 : index
+    %c32 = arith.constant 32 : index
+
+    // Single 2‑D scf.parallel. Each dimension is mapped to a GPU dim.
+    // We *use* both IVs so the conversion must build indices.
+    scf.parallel (%bx, %tx) = (%c0, %c0) to (%c32, %c32) step (%c1, %c1) {
+      %u = arith.addi %bx, %c0 : index
+      %v = arith.addi %tx, %c0 : index
+      // No explicit terminator: the parser inserts an empty scf.reduce.
+    } {
+      mapping = [
+        #gpu.loop_dim_map<processor = block_x,  map = (d0) -> (d0), bound = (d0) -> (d0)>,
+        #gpu.loop_dim_map<processor = thread_x, map = (d0) -> (d0), bound = (d0) -> (d0)>
+      ]
+    }
+    return
+  }
+}
+
+// CHECK-LABEL: func.func @two_dim_parallel_mapped
+// CHECK:       gpu.launch
+// CHECK:       affine.apply
diff --git a/mlir/test/Conversion/SCFToGPU/parallel-to-gpu-index-creation.mlir b/mlir/test/Conversion/SCFToGPU/parallel-to-gpu-index-creation.mlir
new file mode 100644
index 0000000000000..55e425a77c18f
--- /dev/null
+++ b/mlir/test/Conversion/SCFToGPU/parallel-to-gpu-index-creation.mlir
@@ -0,0 +1,24 @@
+// RUN: mlir-opt %s --convert-parallel-loops-to-gpu | FileCheck %s
+
+module {
+  func.func @one_dim_parallel_mapped() {
+    %c0  = arith.constant 0 : index
+    %c1  = arith.constant 1 : index
+    %c64 = arith.constant 64 : index
+
+    // 1‑D loop mapped to thread_x; use the IV to force index computation.
+    scf.parallel (%t) = (%c0) to (%c64) step (%c1) {
+      %w = arith.addi %t, %c0 : index
+      // Implicit empty scf.reduce terminator.
+    } {
+      mapping = [
+        #gpu.loop_dim_map<processor = thread_x, map = (d0) -> (d0), bound = (d0) -> (d0)>
+      ]
+    }
+    return
+  }
+}
+
+// CHECK-LABEL: func.func @one_dim_parallel_mapped
+// CHECK:       gpu.launch
+// CHECK:       affine.apply
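For reference, the index the patch builds is annotation.getMap().compose(lowerAndStep) applied to (hardware id, step, lowerBound), where lowerAndStep is (d0)[s0, s1] -> (d0 * s0 + s1). The sketch below only evaluates that arithmetic in plain C++; the function names are hypothetical, the annotation map is modeled as a one-argument callable, and the real code emits an affine.apply rather than computing a number. Because compose yields "annotation map after lowerAndStep", lowerAndStep is applied first.

```cpp
#include <cassert>
#include <cstdint>
#include <functional>

// lowerAndStep from the patch: (d0)[s0, s1] -> (d0 * s0 + s1),
// with d0 = hardware id, s0 = step, s1 = lower bound.
int64_t lowerAndStep(int64_t hardwareId, int64_t step, int64_t lowerBound) {
  return hardwareId * step + lowerBound;
}

// Hypothetical model of annotation.getMap().compose(lowerAndStep):
// evaluate lowerAndStep first, then feed its result to the 1-D annotation map.
int64_t composedIndex(const std::function<int64_t(int64_t)> &annotationMap,
                      int64_t hardwareId, int64_t step, int64_t lowerBound) {
  return annotationMap(lowerAndStep(hardwareId, step, lowerBound));
}
```

In both regression tests the step is %c1, the lower bound is %c0, and the annotation map is (d0) -> (d0), so the computed index is simply the hardware id.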

@linuxlonelyeagle (Member) left a comment

Some test-related suggestions.

@shashforge force-pushed the mlir-scf2gpu-affineapply-crash branch from 388f412 to 3fab4aa on November 16, 2025 19:20
@shashforge (Contributor, Author) commented Nov 16, 2025

@linuxlonelyeagle @joker-eph: I’ve moved the tests to mlir/test/Conversion/SCFToGPU/parallel_loop.mlir and simplified the checks per the Testing Guide. When you have a chance, could you please take another look? I’m happy to iterate on any additional feedback. Thanks!

@shashforge force-pushed the mlir-scf2gpu-affineapply-crash branch from 36c8e3f to f5fbbba on November 16, 2025 19:45
@github-actions

🐧 Linux x64 Test Results

  • 7080 tests passed
  • 594 tests skipped

@shashforge force-pushed the mlir-scf2gpu-affineapply-crash branch from 125c3c4 to 59235e6 on November 17, 2025 21:00
@shashforge requested a review from joker-eph on November 17, 2025 21:02
// CHECK-LABEL: func.func @scf2gpu_index_creation_2d
// CHECK: gpu.launch
// CHECK: affine.apply

Member

// CHECK-LABEL: func.func @scf2gpu_index_creation_2d -> // CHECK-LABEL: func @scf2gpu_index_creation_2d

Also, you should capture the SSA value and then match it, rather than simply matching the operation.

Member

// CHECK-LABEL: func.func @scf2gpu_index_creation_1d
// CHECK: gpu.launch
// CHECK: affine.apply

Member

The file should end with exactly one trailing newline; GitHub does not display that final newline.

@linuxlonelyeagle (Member) left a comment

Some formatting changes are required.



Development

Successfully merging this pull request may close these issues.

[MLIR] convert-parallel-loops-to-gpu crashes in processParallelLoop creating AffineApplyOp

4 participants