Skip to content

Commit cbb1cce

Browse files
authored
[Loop Specialization]: Specialize loops containing masked operations with a loop-invariant mask (#3586)
Version loops containing masked operations (e.g. tt.load with a mask) where the mask is loop-invariant.

Signed-off-by: Tiotto, Ettore <[email protected]>
1 parent 30bcd8e commit cbb1cce

File tree

5 files changed

+360
-176
lines changed

5 files changed

+360
-176
lines changed

test/Triton/Intel/RaiseToBlockPointers/kernel-03-matrix-multiplication.mlir

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
// RUN: triton-opt %s -triton-intel-remove-masks -triton-raise-block-pointer -canonicalize | FileCheck %s
22

33
module {
4+
// COM: Derived from tutorial 03-matrix-multiplication.
45
tt.func public @matmul_kernel(%arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg3: i32 {tt.divisibility = 16 : i32}, %arg4: i32 {tt.divisibility = 16 : i32}, %arg5: i32 {tt.divisibility = 16 : i32}, %arg6: i32 {tt.divisibility = 16 : i32}, %arg7: i32 {tt.divisibility = 16 : i32}, %arg8: i32 {tt.divisibility = 16 : i32}) {
56
%c31_i32 = arith.constant 31 : i32
67
%cst = arith.constant dense<0.000000e+00> : tensor<64x128xf32>

third_party/intel/backend/compiler.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -224,6 +224,8 @@ def make_ttir(mod, metadata, opt):
224224
pm.enable_debug()
225225
passes.common.add_inliner(pm)
226226
passes.ttir.add_combine(pm)
227+
passes.common.add_cse(pm)
228+
passes.common.add_licm(pm)
227229
intel.passes.ttir.add_remove_masks(pm)
228230
if raise_block_ptr_flags['enabled']:
229231
ignore_masks = True if raise_block_ptr_flags['ignore-masks'] else False

0 commit comments

Comments
 (0)