
Conversation

@Nullkooland
Contributor

As discussed in #216, with the upstream fix llvm/llvm-project#114045, linalg ops now implement the RecursiveMemoryEffects trait, so we can convert tts.scatter to a linalg.generic whose body performs a memref.store for each scalar index and value element.

For instance, triton_shared/test/Conversion/UnstructuredToMemref/gather_scatter_all_mask.mlir:

// RUN: triton-shared-opt --triton-to-unstructured --canonicalize --unstructured-to-memref --canonicalize %s

#map = affine_map<(d0) -> (d0)>
module {
  tt.func public @masked_gather_scatter(%arg0: !tt.ptr<f32>, %arg1: !tt.ptr<f32>) attributes {noinline = false} {
    %cst = arith.constant dense<3> : tensor<4xi32>
    %cst_0 = arith.constant dense<64> : tensor<4xi32>
    %cst_1 = arith.constant dense<4> : tensor<4xi32>
    %c2_i32 = arith.constant 2 : i32
    %c1_i32 = arith.constant 1 : i32
    %c0_i32 = arith.constant 0 : i32
    %cst_2 = arith.constant 9.900000e+01 : f32
    %0 = builtin.unrealized_conversion_cast %arg1 : !tt.ptr<f32> to memref<*xf32>
    %1 = builtin.unrealized_conversion_cast %arg0 : !tt.ptr<f32> to memref<*xf32>
    %2 = tt.make_range {end = 4 : i32, start = 0 : i32} : tensor<4xi32>
    %3:2 = scf.for %arg2 = %c0_i32 to %c2_i32 step %c1_i32 iter_args(%arg3 = %2, %arg4 = %2) -> (tensor<4xi32>, tensor<4xi32>)  : i32 {
      %4 = arith.divsi %arg3, %cst : tensor<4xi32>
      %5 = tt.splat %arg2 : i32 -> tensor<4xi32>
      %6 = arith.addi %4, %5 : tensor<4xi32>
      %7 = arith.cmpi slt, %6, %cst_0 : tensor<4xi32>
      %cast = memref.cast %1 : memref<*xf32> to memref<?xf32>
      %8 = bufferization.to_tensor %cast restrict : memref<?xf32> to tensor<?xf32>
      %9 = tensor.empty() : tensor<4xf32>
      %10 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]} ins(%6, %7 : tensor<4xi32>, tensor<4xi1>) outs(%9 : tensor<4xf32>) {
      ^bb0(%in: i32, %in_4: i1, %out: f32):
        %13 = scf.if %in_4 -> (f32) {
          %14 = arith.index_cast %in : i32 to index
          %extracted = tensor.extract %8[%14] : tensor<?xf32>
          scf.yield %extracted : f32
        } else {
          scf.yield %cst_2 : f32
        }
        linalg.yield %13 : f32
      } -> tensor<4xf32>
      %cast_3 = memref.cast %0 : memref<*xf32> to memref<?xf32>
      // tts.scatter lowers to:
      linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]} ins(%6, %10, %7 : tensor<4xi32>, tensor<4xf32>, tensor<4xi1>) {
      ^bb0(%in: i32, %in_4: f32, %in_5: i1):
        scf.if %in_5 {
          %13 = arith.index_cast %in : i32 to index
          memref.store %in_4, %cast_3[%13] : memref<?xf32>
        }
        linalg.yield
      }
      %11 = arith.addi %6, %cst_1 : tensor<4xi32>
      %12 = arith.addi %arg4, %cst_1 : tensor<4xi32>
      scf.yield %11, %12 : tensor<4xi32>, tensor<4xi32>
    }
    tt.return
  }
}
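For reference, a Triton kernel of roughly the following shape would lower to the IR above. This is a hypothetical reconstruction inferred from the lowered test (the kernel signature, pointer names, 99.0 padding value, and loop bounds are assumptions), not the actual source of the test:

import triton
import triton.language as tl

@triton.jit
def masked_gather_scatter(src_ptr, dst_ptr):
    # tl.arange(0, 4) corresponds to the tt.make_range in the IR above.
    offsets = tl.arange(0, 4)
    for i in range(2):
        idx = offsets // 3 + i
        mask = idx < 64
        # Masked gather: lowers to the first linalg.generic, which does a
        # tensor.extract per element and falls back to 99.0 where masked off.
        x = tl.load(src_ptr + idx, mask=mask, other=99.0)
        # Masked scatter: lowers to the zero-result linalg.generic whose body
        # performs a memref.store per element, guarded by scf.if on the mask.
        tl.store(dst_ptr + idx, x, mask=mask)
        offsets = idx + 4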

We can also utilize linalg-fuse-elementwise-ops now, which fuses the gather and scatter linalg.generic ops into a single one:

// RUN: triton-shared-opt --linalg-fuse-elementwise-ops --canonicalize %s

#map = affine_map<(d0) -> (d0)>
module {
  tt.func public @masked_gather_scatter(%arg0: !tt.ptr<f32>, %arg1: !tt.ptr<f32>) attributes {noinline = false} {
    %cst = arith.constant dense<3> : tensor<4xi32>
    %cst_0 = arith.constant dense<64> : tensor<4xi32>
    %cst_1 = arith.constant dense<4> : tensor<4xi32>
    %c2_i32 = arith.constant 2 : i32
    %c1_i32 = arith.constant 1 : i32
    %c0_i32 = arith.constant 0 : i32
    %0 = builtin.unrealized_conversion_cast %arg1 : !tt.ptr<f32> to memref<*xf32>
    %1 = builtin.unrealized_conversion_cast %arg0 : !tt.ptr<f32> to memref<*xf32>
    %2 = tt.make_range {end = 4 : i32, start = 0 : i32} : tensor<4xi32>
    %3:2 = scf.for %arg2 = %c0_i32 to %c2_i32 step %c1_i32 iter_args(%arg3 = %2, %arg4 = %2) -> (tensor<4xi32>, tensor<4xi32>)  : i32 {
      %4 = arith.divsi %arg3, %cst : tensor<4xi32>
      %5 = tt.splat %arg2 : i32 -> tensor<4xi32>
      %6 = arith.addi %4, %5 : tensor<4xi32>
      %7 = arith.cmpi slt, %6, %cst_0 : tensor<4xi32>
      %cast = memref.cast %1 : memref<*xf32> to memref<?xf32>
      %8 = bufferization.to_tensor %cast restrict : memref<?xf32> to tensor<?xf32>
      %cast_2 = memref.cast %0 : memref<*xf32> to memref<?xf32>
      linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel"]} ins(%6, %7 : tensor<4xi32>, tensor<4xi1>) {
      ^bb0(%in: i32, %in_3: i1):
        scf.if %in_3 {
          %11 = arith.index_cast %in : i32 to index
          %extracted = tensor.extract %8[%11] : tensor<?xf32>
          %12 = arith.index_cast %in : i32 to index
          memref.store %extracted, %cast_2[%12] : memref<?xf32>
        }
        linalg.yield
      }
      %9 = arith.addi %6, %cst_1 : tensor<4xi32>
      %10 = arith.addi %arg4, %cst_1 : tensor<4xi32>
      scf.yield %9, %10 : tensor<4xi32>, tensor<4xi32>
    }
    tt.return
  }
}

Nullkooland force-pushed the tts_scatter_to_linalg_generic_store branch from 00e7118 to 65a3e6b on February 13, 2025 10:19
Nullkooland force-pushed the tts_scatter_to_linalg_generic_store branch from 65a3e6b to d0fe5a9 on February 13, 2025 10:21
Contributor

@nhat-nguyen left a comment


thank you! 💯

nhat-nguyen merged commit ccba545 into microsoft:main on Feb 14, 2025
3 checks passed
@nhat-nguyen
Contributor

@Nullkooland Thanks for the patch. By the way, this implementation has a strict restriction: loads/stores are assumed to have a single base pointer, and that base must come from the kernel arguments. This assumption does not hold for all Triton programs. We're working on the multiple-base scenarios. :)
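To illustrate the restriction, here is a hypothetical Triton kernel (illustrative only, not from this repository) where the store base is selected at runtime, so the access cannot be traced back to a single kernel-argument base:

import triton
import triton.language as tl

@triton.jit
def two_base_scatter(a_ptr, b_ptr, flag):
    offsets = tl.arange(0, 4)
    # The base pointer is data-dependent: it may be a_ptr or b_ptr, so the
    # store below does not have a single statically known base argument.
    if flag:
        base = a_ptr
    else:
        base = b_ptr
    vals = tl.zeros([4], dtype=tl.float32)
    tl.store(base + offsets, vals)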

Nullkooland deleted the tts_scatter_to_linalg_generic_store branch on February 17, 2025 09:56
nhat-nguyen added a commit that referenced this pull request Feb 20, 2025
