Description
When a Triton::MakeTensorPtrOp has to be rewritten by the RewriteTensorPointer pass to use "regular" memory operations, the generated code appears less performant than code written directly with regular operations.
Indeed, the Triton::AdvanceOp operations are used as anchors to generate the new memory accesses, which causes the entire pointer-calculation code to end up inside the loop, even though a significant part of these instructions could be hoisted out of it.
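At the kernel-source level, the code that takes this path is a block-pointer loop built with tl.make_block_ptr and tl.advance, which lower to tt.make_tensor_ptr and tt.advance. Below is a minimal hand-written sketch of such a kernel, assuming the standard block-pointer API; the kernel name, parameter names, and block sizes are illustrative and not taken from the tutorial:

```python
import triton
import triton.language as tl

@triton.jit
def matmul_blockptr_kernel(a_ptr, b_ptr, c_ptr, M, N, K,
                           stride_am, stride_ak, stride_bk, stride_bn,
                           stride_cm, stride_cn,
                           BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr,
                           BLOCK_K: tl.constexpr):
    pid_m = tl.program_id(0)
    pid_n = tl.program_id(1)
    # Block pointers: base, shape and strides are loop-invariant; only the
    # offsets change across K iterations, via tl.advance.
    a_block_ptr = tl.make_block_ptr(base=a_ptr, shape=(M, K),
                                    strides=(stride_am, stride_ak),
                                    offsets=(pid_m * BLOCK_M, 0),
                                    block_shape=(BLOCK_M, BLOCK_K), order=(1, 0))
    b_block_ptr = tl.make_block_ptr(base=b_ptr, shape=(K, N),
                                    strides=(stride_bk, stride_bn),
                                    offsets=(0, pid_n * BLOCK_N),
                                    block_shape=(BLOCK_K, BLOCK_N), order=(1, 0))
    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
    for _ in range(0, tl.cdiv(K, BLOCK_K)):
        a = tl.load(a_block_ptr, boundary_check=(0, 1))
        b = tl.load(b_block_ptr, boundary_check=(0, 1))
        acc = tl.dot(a, b, acc)
        # The tt.advance ops generated here are the anchors the rewrite uses.
        a_block_ptr = tl.advance(a_block_ptr, (0, BLOCK_K))
        b_block_ptr = tl.advance(b_block_ptr, (BLOCK_K, 0))
    c_block_ptr = tl.make_block_ptr(base=c_ptr, shape=(M, N),
                                    strides=(stride_cm, stride_cn),
                                    offsets=(pid_m * BLOCK_M, pid_n * BLOCK_N),
                                    block_shape=(BLOCK_M, BLOCK_N), order=(1, 0))
    tl.store(c_block_ptr, acc.to(tl.float16), boundary_check=(0, 1))
```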
For example, the relevant section of the TritonGPU MLIR for the 03 tutorial (written with regular pointers) looks like:
```mlir
%15 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>>
%16 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>>
%17 = tt.splat %14 : i32 -> tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>>
%18 = tt.splat %14 : i32 -> tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>>
%19 = arith.addi %17, %15 : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>>
%20 = arith.addi %18, %16 : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>>
%21 = tt.splat %arg3 : i32 -> tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>>
%22 = arith.remsi %19, %21 : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>>
%23 = arith.muli %13, %c256_i32 : i32
%24 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>>
%25 = tt.splat %23 : i32 -> tensor<256xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>>
%26 = arith.addi %25, %24 : tensor<256xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>>
%27 = tt.splat %arg4 : i32 -> tensor<256xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>>
%28 = arith.remsi %26, %27 : tensor<256xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>>
%29 = tt.expand_dims %22 {axis = 1 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>> -> tensor<128x1xi32, #blocked1>
%30 = tt.splat %arg6 : i32 -> tensor<128x1xi32, #blocked1>
%31 = arith.muli %29, %30 : tensor<128x1xi32, #blocked1>
%32 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>>
%33 = tt.expand_dims %32 {axis = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> -> tensor<1x64xi32, #blocked1>
%34 = tt.broadcast %31 : tensor<128x1xi32, #blocked1> -> tensor<128x64xi32, #blocked1>
%35 = tt.broadcast %33 : tensor<1x64xi32, #blocked1> -> tensor<128x64xi32, #blocked1>
%36 = arith.addi %34, %35 : tensor<128x64xi32, #blocked1>
%37 = tt.splat %arg0 : !tt.ptr<f16> -> tensor<128x64x!tt.ptr<f16>, #blocked1>
%38 = tt.addptr %37, %36 : tensor<128x64x!tt.ptr<f16>, #blocked1>, tensor<128x64xi32, #blocked1>
%39 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>>
%40 = tt.expand_dims %39 {axis = 1 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> -> tensor<64x1xi32, #blocked>
%41 = tt.splat %arg7 : i32 -> tensor<64x1xi32, #blocked>
%42 = arith.muli %40, %41 : tensor<64x1xi32, #blocked>
%43 = tt.expand_dims %28 {axis = 0 : i32} : tensor<256xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> -> tensor<1x256xi32, #blocked>
%44 = tt.broadcast %42 : tensor<64x1xi32, #blocked> -> tensor<64x256xi32, #blocked>
%45 = tt.broadcast %43 : tensor<1x256xi32, #blocked> -> tensor<64x256xi32, #blocked>
%46 = arith.addi %44, %45 : tensor<64x256xi32, #blocked>
%47 = tt.splat %arg1 : !tt.ptr<f16> -> tensor<64x256x!tt.ptr<f16>, #blocked>
%48 = tt.addptr %47, %46 : tensor<64x256x!tt.ptr<f16>, #blocked>, tensor<64x256xi32, #blocked>
%49 = arith.addi %arg5, %c63_i32 : i32
%50 = arith.divsi %49, %c64_i32 : i32
%51 = arith.muli %arg7, %c64_i32 : i32
%52 = tt.splat %51 : i32 -> tensor<64x256xi32, #blocked>
%53:3 = scf.for %arg9 = %c0_i32 to %50 step %c1_i32 iter_args(%arg10 = %cst, %arg11 = %38, %arg12 = %48) -> (tensor<128x256xf32, #mma>, tensor<128x64x!tt.ptr<f16>, #blocked1>, tensor<64x256x!tt.ptr<f16>, #blocked>) : i32 {
%72 = arith.muli %arg9, %c64_i32 : i32
%73 = arith.subi %arg5, %72 : i32
%74 = tt.splat %73 : i32 -> tensor<1x64xi32, #blocked1>
%75 = arith.cmpi slt, %33, %74 : tensor<1x64xi32, #blocked1>
%76 = tt.broadcast %75 : tensor<1x64xi1, #blocked1> -> tensor<128x64xi1, #blocked1>
%77 = tt.load %arg11, %76, %cst_1 : tensor<128x64x!tt.ptr<f16>, #blocked1>
%78 = triton_gpu.local_alloc %77 : (tensor<128x64xf16, #blocked1>) -> !tt.memdesc<128x64xf16, #shared, #triton_gpu.shared_memory>
%79 = tt.splat %73 : i32 -> tensor<64x1xi32, #blocked>
%80 = arith.cmpi slt, %40, %79 : tensor<64x1xi32, #blocked>
%81 = tt.broadcast %80 : tensor<64x1xi1, #blocked> -> tensor<64x256xi1, #blocked>
%82 = tt.load %arg12, %81, %cst_0 : tensor<64x256x!tt.ptr<f16>, #blocked>
%83 = triton_gpu.local_alloc %82 : (tensor<64x256xf16, #blocked>) -> !tt.memdesc<64x256xf16, #shared, #triton_gpu.shared_memory>
%84 = triton_gpu.local_load %78 : !tt.memdesc<128x64xf16, #shared, #triton_gpu.shared_memory> -> tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>>
%85 = triton_gpu.local_load %83 : !tt.memdesc<64x256xf16, #shared, #triton_gpu.shared_memory> -> tensor<64x256xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>>
%86 = tt.dot %84, %85, %arg10, inputPrecision = tf32 : tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * tensor<64x256xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<128x256xf32, #mma>
%87 = tt.addptr %arg11, %cst_2 : tensor<128x64x!tt.ptr<f16>, #blocked1>, tensor<128x64xi32, #blocked1>
%88 = tt.addptr %arg12, %52 : tensor<64x256x!tt.ptr<f16>, #blocked>, tensor<64x256xi32, #blocked>
scf.yield %86, %87, %88 : tensor<128x256xf32, #mma>, tensor<128x64x!tt.ptr<f16>, #blocked1>, tensor<64x256x!tt.ptr<f16>, #blocked>
}
```
In contrast, the code for an equivalent of the 03 tutorial written with block pointers (after forcing the block pointers to be rewritten) looks like:
```mlir
%18 = arith.extsi %arg7 : i32 to i64
%19 = arith.extsi %17 : i32 to i64
%20 = arith.addi %arg5, %c63_i32 : i32
%21 = arith.divsi %20, %c64_i32 : i32
%22:3 = scf.for %arg9 = %c0_i32 to %21 step %c1_i32 iter_args(%arg10 = %cst, %arg11 = %c0_i64, %arg12 = %c0_i64) -> (tensor<128x256xf32, #mma>, i64, i64) : i32 {
%43 = tt.splat %arg0 : !tt.ptr<f16> -> tensor<128x64x!tt.ptr<f16>, #blocked>
%44 = tt.splat %16 : i64 -> tensor<128xi64, #triton_gpu.slice<{dim = 1, parent = #blocked}>>
%45 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>>
%46 = arith.extsi %45 : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> to tensor<128xi64, #triton_gpu.slice<{dim = 1, parent = #blocked}>>
%47 = arith.addi %44, %46 : tensor<128xi64, #triton_gpu.slice<{dim = 1, parent = #blocked}>>
%48 = tt.expand_dims %47 {axis = 1 : i32} : tensor<128xi64, #triton_gpu.slice<{dim = 1, parent = #blocked}>> -> tensor<128x1xi64, #blocked>
%49 = tt.splat %15 : i64 -> tensor<128x1xi64, #blocked>
%50 = arith.muli %48, %49 : tensor<128x1xi64, #blocked>
%51 = tt.broadcast %50 : tensor<128x1xi64, #blocked> -> tensor<128x64xi64, #blocked>
%52 = tt.addptr %43, %51 : tensor<128x64x!tt.ptr<f16>, #blocked>, tensor<128x64xi64, #blocked>
%53 = tt.splat %arg11 : i64 -> tensor<64xi64, #triton_gpu.slice<{dim = 0, parent = #blocked}>>
%54 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>>
%55 = arith.extsi %54 : tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> to tensor<64xi64, #triton_gpu.slice<{dim = 0, parent = #blocked}>>
%56 = arith.addi %53, %55 : tensor<64xi64, #triton_gpu.slice<{dim = 0, parent = #blocked}>>
%57 = tt.expand_dims %56 {axis = 0 : i32} : tensor<64xi64, #triton_gpu.slice<{dim = 0, parent = #blocked}>> -> tensor<1x64xi64, #blocked>
%58 = tt.broadcast %57 : tensor<1x64xi64, #blocked> -> tensor<128x64xi64, #blocked>
%59 = tt.addptr %52, %58 : tensor<128x64x!tt.ptr<f16>, #blocked>, tensor<128x64xi64, #blocked>
%60 = tt.load %59 : tensor<128x64x!tt.ptr<f16>, #blocked>
%61 = triton_gpu.local_alloc %60 : (tensor<128x64xf16, #blocked>) -> !tt.memdesc<128x64xf16, #shared, #triton_gpu.shared_memory>
%62 = triton_gpu.local_load %61 : !tt.memdesc<128x64xf16, #shared, #triton_gpu.shared_memory> -> tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>>
%63 = tt.splat %arg1 : !tt.ptr<f16> -> tensor<64x256x!tt.ptr<f16>, #blocked1>
%64 = tt.splat %arg12 : i64 -> tensor<64xi64, #triton_gpu.slice<{dim = 1, parent = #blocked1}>>
%65 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>>
%66 = arith.extsi %65 : tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>> to tensor<64xi64, #triton_gpu.slice<{dim = 1, parent = #blocked1}>>
%67 = arith.addi %64, %66 : tensor<64xi64, #triton_gpu.slice<{dim = 1, parent = #blocked1}>>
%68 = tt.expand_dims %67 {axis = 1 : i32} : tensor<64xi64, #triton_gpu.slice<{dim = 1, parent = #blocked1}>> -> tensor<64x1xi64, #blocked1>
%69 = tt.splat %18 : i64 -> tensor<64x1xi64, #blocked1>
%70 = arith.muli %68, %69 : tensor<64x1xi64, #blocked1>
%71 = tt.broadcast %70 : tensor<64x1xi64, #blocked1> -> tensor<64x256xi64, #blocked1>
%72 = tt.addptr %63, %71 : tensor<64x256x!tt.ptr<f16>, #blocked1>, tensor<64x256xi64, #blocked1>
%73 = tt.splat %19 : i64 -> tensor<256xi64, #triton_gpu.slice<{dim = 0, parent = #blocked1}>>
%74 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>>
%75 = arith.extsi %74 : tensor<256xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> to tensor<256xi64, #triton_gpu.slice<{dim = 0, parent = #blocked1}>>
%76 = arith.addi %73, %75 : tensor<256xi64, #triton_gpu.slice<{dim = 0, parent = #blocked1}>>
%77 = tt.expand_dims %76 {axis = 0 : i32} : tensor<256xi64, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> -> tensor<1x256xi64, #blocked1>
%78 = tt.broadcast %77 : tensor<1x256xi64, #blocked1> -> tensor<64x256xi64, #blocked1>
%79 = tt.addptr %72, %78 : tensor<64x256x!tt.ptr<f16>, #blocked1>, tensor<64x256xi64, #blocked1>
%80 = tt.load %79 : tensor<64x256x!tt.ptr<f16>, #blocked1>
%81 = triton_gpu.local_alloc %80 : (tensor<64x256xf16, #blocked1>) -> !tt.memdesc<64x256xf16, #shared, #triton_gpu.shared_memory>
%82 = triton_gpu.local_load %81 : !tt.memdesc<64x256xf16, #shared, #triton_gpu.shared_memory> -> tensor<64x256xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>>
%83 = tt.dot %62, %82, %arg10, inputPrecision = tf32 : tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * tensor<64x256xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<128x256xf32, #mma>
%84 = arith.addi %arg11, %c64_i64 : i64
%85 = arith.addi %arg12, %c64_i64 : i64
scf.yield %83, %84, %85 : tensor<128x256xf32, #mma>, i64, i64
}
```
The RewriteTensorPointer pass should therefore be improved to hoist these loop-invariant instructions out of the loop.
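For reference, the pointer handling that produces the first (hoisted) dump is the regular-pointer style of the 03 tutorial: the offset tensors are computed once before the K loop and only a constant pointer increment remains inside it, which is the shape the rewritten IR should ideally take as well. Below is a rough sketch of that inner section as a @triton.jit helper; the function and parameter names are illustrative, not the tutorial source verbatim:

```python
import triton
import triton.language as tl

@triton.jit
def matmul_regular_ptr_inner(a_ptr, b_ptr, pid_m, pid_n, M, N, K,
                             stride_am, stride_ak, stride_bk, stride_bn,
                             BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr,
                             BLOCK_K: tl.constexpr):
    # Offset tensors are built once, before the loop: this is the code that
    # appears above the scf.for in the first dump.
    offs_am = (pid_m * BLOCK_M + tl.arange(0, BLOCK_M)) % M
    offs_bn = (pid_n * BLOCK_N + tl.arange(0, BLOCK_N)) % N
    offs_k = tl.arange(0, BLOCK_K)
    a_ptrs = a_ptr + offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak
    b_ptrs = b_ptr + offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn
    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
    for k in range(0, tl.cdiv(K, BLOCK_K)):
        a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_K, other=0.0)
        b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_K, other=0.0)
        acc = tl.dot(a, b, acc)
        # Inside the loop only a constant pointer increment remains
        # (the two tt.addptr ops at the bottom of the loop body in the dump).
        a_ptrs += BLOCK_K * stride_ak
        b_ptrs += BLOCK_K * stride_bk
    return acc
```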