Skip to content

Commit e9262ad

Browse files
authored
Introduce unstructured-to-memref pass (#216)
This PR introduces the `unstructured-to-memref` pass responsible for converting unstructured triton load / store ops to memref load / store ops. This is part of the work to allow triton-shared to lower gather / scatter pointer sequences. The pass is intended to be used after running `--fold-unstructured-ptr`. Triton load op (gather) is lowered to a `linalg.generic` whose body contains a load from the offset indicated by the offset provided by `tts.make_unstructured_tptr`. For load op with mask, an inner-most `scf.if` is used to return a default value (or the `other` in `tt.load` if provided) if the corresponding mask value is false. Example of a load: ```mlir func.func @gather_simple_mask_with_other(%arg0: memref<*xf32>, %arg1: memref<*xf32>) { %cst = arith.constant -1.000000e+00 : f32 %cast = memref.cast %arg0 : memref<*xf32> to memref<?xf32> %load_tensor = bufferization.to_tensor %cast restrict : memref<?xf32> %out = tensor.empty() : tensor<64xf32> %gather = linalg.generic { iterator_types = ["parallel"] } ins(%offset_tensor, %mask_tensor : tensor<64xi32>, tensor<64xi1>) outs(%out : tensor<64xf32>) { ^bb0(%offset: i32, %mask: i1, %out: f32): %yield = scf.if %mask -> (f32) { %index = arith.index_cast %offset : i32 to index %extracted = tensor.extract %load_tensor[%index] : tensor<?xf32> scf.yield %extracted : f32 } else { scf.yield %cst : f32 } linalg.yield %yield : f32 } -> tensor<64xf32> ``` Triton store op (scatter) is lowered to an `affine.for` loop nest that stores the value to the appropriate offset provided by `tts.make_unstructured_tptr`. Store op with mask is also supported. Example of a store: ```mlir func.func @masked_gather_scatter(%arg0: memref<*xf32>, %arg1: memref<*xf32>) { %store_memref = memref.cast %arg1 : memref<*xf32> to memref<?xf32> affine.for %i = 0 to 4 { %mask_val = tensor.extract %mask[%i] : tensor<4xi1> scf.if %mask_val { %offset_val = tensor.extract %offset_tensor[%i] : tensor<4xi32> %store_value = tensor.extract %tensor[%i] : tensor<4xf32> %offset_index = arith.index_cast %offset_val : i32 to index memref.store %store_value, %store_memref[%offset_index] : memref<?xf32> } } ``` --- # Intended lowering pipeline - triton-to-structured (no changes): - analyzes structured addptr sequences - introduces `tts.make_tptr %ptr_arg with offsets and strides` - introduces `tts.load` and `tts.store` - leaves unstructured addptr sequences and their corresponding `tt.load` and `tt.store` intact - triton-to-unstructured (#210): - introduces `tts.gather` and `tts.scatter` - removes all pointer-producing ops such as `tt.addptr` and `tt.splat` and replaces them with offset-producing ops - structured-to-memref (#217): - currently converts everything to memref including scalar addptr and kernel arguments - will change to just convert ops in the `tts` dialect to `memref` with the exception of `tts.gather` and `tts.scatter` - unstructured-to-memref (#216): - converts the remaining unstructured `tts.gather`, `tts.scatter` into memref - triton-ptr-to-memref (#211): - converts kernel arguments with pointer type to memref
1 parent 91ac8d8 commit e9262ad

File tree

13 files changed

+824
-1
lines changed

13 files changed

+824
-1
lines changed

include/triton-shared/Conversion/CMakeLists.txt

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ add_subdirectory(TritonToLinalg)
22
add_subdirectory(TritonToLinalgExperimental)
33
add_subdirectory(TritonToStructured)
44
add_subdirectory(TritonArithToLinalg)
5+
add_subdirectory(TritonPtrToMemref)
56
add_subdirectory(TritonToUnstructured)
67
add_subdirectory(StructuredToMemref)
7-
add_subdirectory(TritonPtrToMemref)
8+
add_subdirectory(UnstructuredToMemref)
Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
#===------------------------------------------------------------------------===#
2+
#
3+
# Copyright (c) Microsoft Corporation.
4+
# Licensed under the MIT license.
5+
#
6+
#===------------------------------------------------------------------------===#
7+
8+
set(LLVM_TARGET_DEFINITIONS Passes.td)
9+
mlir_tablegen(Passes.h.inc -gen-pass-decls --name UnstructuredToMemref)
10+
add_public_tablegen_target(UnstructuredToMemrefConversionPassIncGen)
Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
//===----------------------------------------------------------------------===//
2+
//
3+
// Copyright (c) Microsoft Corporation.
4+
// Licensed under the MIT license.
5+
//
6+
//===----------------------------------------------------------------------===//
7+
8+
#ifndef UNSTRUCTURED_TO_MEMREF_CONVERSION_PASSES_H
9+
#define UNSTRUCTURED_TO_MEMREF_CONVERSION_PASSES_H
10+
11+
#include "triton-shared/Conversion/UnstructuredToMemref/UnstructuredToMemref.h"
12+
13+
namespace mlir {
14+
namespace triton {
15+
16+
#define GEN_PASS_REGISTRATION
17+
#include "triton-shared/Conversion/UnstructuredToMemref/Passes.h.inc"
18+
19+
} // namespace triton
20+
} // namespace mlir
21+
22+
#endif
Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
//===----------------------------------------------------------------------===//
2+
//
3+
// Copyright (c) Microsoft Corporation.
4+
// Licensed under the MIT license.
5+
//
6+
//===----------------------------------------------------------------------===//
7+
8+
#ifndef UNSTRUCTURED_TO_MEMREF_CONVERSION_PASSES
9+
#define UNSTRUCTURED_TO_MEMREF_CONVERSION_PASSES
10+
11+
include "mlir/Pass/PassBase.td"
12+
13+
def UnstructuredToMemref : Pass<"unstructured-to-memref", "mlir::ModuleOp"> {
14+
let summary = "Convert unstructured triton ptr (gather / scatter) to memref";
15+
let constructor = "triton::createUnstructuredToMemrefPass()";
16+
}
17+
18+
#endif
Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
//===----------------------------------------------------------------------===//
2+
//
3+
// Copyright (c) Microsoft Corporation.
4+
// Licensed under the MIT license.
5+
//
6+
//===----------------------------------------------------------------------===//
7+
8+
#ifndef TRITON_CONVERSION_UNSTRUCTUREDTOMEMREF_UNSTRUCTUREDTOMEMREF_H
9+
#define TRITON_CONVERSION_UNSTRUCTUREDTOMEMREF_UNSTRUCTUREDTOMEMREF_H
10+
11+
#include "mlir/Pass/Pass.h"
12+
13+
namespace mlir {
14+
namespace triton {
15+
16+
std::unique_ptr<OperationPass<ModuleOp>> createUnstructuredToMemrefPass();
17+
18+
} // namespace triton
19+
} // namespace mlir
20+
21+
#endif // TRITON_CONVERSION_UNSTRUCTUREDTOMEMREF_UNSTRUCTUREDTOMEMREF_H

lib/Conversion/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,3 +5,4 @@ add_subdirectory(TritonToUnstructured)
55
add_subdirectory(TritonArithToLinalg)
66
add_subdirectory(StructuredToMemref)
77
add_subdirectory(TritonPtrToMemref)
8+
add_subdirectory(UnstructuredToMemref)
Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
#===------------------------------------------------------------------------===#
2+
#
3+
# Copyright (c) Microsoft Corporation.
4+
# Licensed under the MIT license.
5+
#
6+
#===------------------------------------------------------------------------===#
7+
8+
add_triton_library(UnstructuredToMemref
9+
UnstructuredToMemrefPass.cpp
10+
11+
DEPENDS
12+
UnstructuredToMemrefConversionPassIncGen
13+
14+
LINK_LIBS PUBLIC
15+
TritonTilingExtIR
16+
MLIRArithDialect
17+
MLIRDialectUtils
18+
MLIRIR
19+
MLIRMathDialect
20+
MLIRPass
21+
MLIRTensorDialect
22+
MLIRTransforms
23+
MLIRSupport
24+
TritonAnalysis
25+
TritonIR
26+
TritonTransforms
27+
TritonSharedAnalysis
28+
)

0 commit comments

Comments
 (0)