
[Performance] Improve the code generated by the RewriteTensorPointer pass. #1766

@mfrancepillois

Description


When a Triton::MakeTensorPtrOp has to be rewritten by the RewriteTensorPointer pass to use "regular" memory operations, the generated code appears to be less performant than code written directly with regular operations.
This is because the Triton::AdvanceOp ops are used as anchors for generating the new memory accesses, which causes the entire pointer-computation code to end up inside the loop, even though a significant part of these instructions could be hoisted out of it.
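The kind of kernel that hits this path is the block-pointer variant of the matmul tutorial, sketched below in simplified form (names, parameters and the omitted epilogue are illustrative, not the exact tutorial code): tl.make_block_ptr lowers to tt.make_tensor_ptr and tl.advance to tt.advance, the op the pass currently uses as its anchor.

import triton
import triton.language as tl

@triton.jit
def matmul_block_ptr_kernel(a_ptr, b_ptr, M, N, K,
                            stride_am, stride_ak, stride_bk, stride_bn,
                            BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr,
                            BLOCK_K: tl.constexpr):
    pid_m = tl.program_id(0)
    pid_n = tl.program_id(1)
    # tt.make_tensor_ptr: the block pointers are created once, before the loop.
    a_block_ptr = tl.make_block_ptr(base=a_ptr, shape=(M, K),
                                    strides=(stride_am, stride_ak),
                                    offsets=(pid_m * BLOCK_M, 0),
                                    block_shape=(BLOCK_M, BLOCK_K), order=(1, 0))
    b_block_ptr = tl.make_block_ptr(base=b_ptr, shape=(K, N),
                                    strides=(stride_bk, stride_bn),
                                    offsets=(0, pid_n * BLOCK_N),
                                    block_shape=(BLOCK_K, BLOCK_N), order=(1, 0))
    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
    for _ in range(0, tl.cdiv(K, BLOCK_K)):
        a = tl.load(a_block_ptr, boundary_check=(0, 1))
        b = tl.load(b_block_ptr, boundary_check=(0, 1))
        acc = tl.dot(a, b, acc)
        # tt.advance: the anchors the pass rewrites around; after rewriting,
        # the full address computation is regenerated here, inside the loop.
        a_block_ptr = tl.advance(a_block_ptr, (0, BLOCK_K))
        b_block_ptr = tl.advance(b_block_ptr, (BLOCK_K, 0))
    # epilogue (store of acc) omitted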

For example:
The relevant section of the TritonGPU MLIR code for the 03 tutorial (matrix multiplication) looks like:

%15 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>> 
%16 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> 
%17 = tt.splat %14 : i32 -> tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>>
%18 = tt.splat %14 : i32 -> tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>>
%19 = arith.addi %17, %15 : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>> 
%20 = arith.addi %18, %16 : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> 
%21 = tt.splat %arg3 : i32 -> tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>>
%22 = arith.remsi %19, %21 : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>>
%23 = arith.muli %13, %c256_i32 : i32
%24 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>>
%25 = tt.splat %23 : i32 -> tensor<256xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>>
%26 = arith.addi %25, %24 : tensor<256xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>>
%27 = tt.splat %arg4 : i32 -> tensor<256xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>>
%28 = arith.remsi %26, %27 : tensor<256xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>>
%29 = tt.expand_dims %22 {axis = 1 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>> -> tensor<128x1xi32, #blocked1>
%30 = tt.splat %arg6 : i32 -> tensor<128x1xi32, #blocked1>
%31 = arith.muli %29, %30 : tensor<128x1xi32, #blocked1>
%32 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>>
%33 = tt.expand_dims %32 {axis = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> -> tensor<1x64xi32, #blocked1>
%34 = tt.broadcast %31 : tensor<128x1xi32, #blocked1> -> tensor<128x64xi32, #blocked1>
%35 = tt.broadcast %33 : tensor<1x64xi32, #blocked1> -> tensor<128x64xi32, #blocked1>
%36 = arith.addi %34, %35 : tensor<128x64xi32, #blocked1>
%37 = tt.splat %arg0 : !tt.ptr<f16> -> tensor<128x64x!tt.ptr<f16>, #blocked1>
%38 = tt.addptr %37, %36 : tensor<128x64x!tt.ptr<f16>, #blocked1>, tensor<128x64xi32, #blocked1>
%39 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>>
%40 = tt.expand_dims %39 {axis = 1 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> -> tensor<64x1xi32, #blocked>
%41 = tt.splat %arg7 : i32 -> tensor<64x1xi32, #blocked>
%42 = arith.muli %40, %41 : tensor<64x1xi32, #blocked>
%43 = tt.expand_dims %28 {axis = 0 : i32} : tensor<256xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> -> tensor<1x256xi32, #blocked>
%44 = tt.broadcast %42 : tensor<64x1xi32, #blocked> -> tensor<64x256xi32, #blocked>
%45 = tt.broadcast %43 : tensor<1x256xi32, #blocked> -> tensor<64x256xi32, #blocked>
%46 = arith.addi %44, %45 : tensor<64x256xi32, #blocked>
%47 = tt.splat %arg1 : !tt.ptr<f16> -> tensor<64x256x!tt.ptr<f16>, #blocked>
%48 = tt.addptr %47, %46 : tensor<64x256x!tt.ptr<f16>, #blocked>, tensor<64x256xi32, #blocked>
%49 = arith.addi %arg5, %c63_i32 : i32
%50 = arith.divsi %49, %c64_i32 : i32
%51 = arith.muli %arg7, %c64_i32 : i32
%52 = tt.splat %51 : i32 -> tensor<64x256xi32, #blocked>
%53:3 = scf.for %arg9 = %c0_i32 to %50 step %c1_i32 iter_args(%arg10 = %cst, %arg11 = %38, %arg12 = %48) -> (tensor<128x256xf32, #mma>, tensor<128x64x!tt.ptr<f16>, #blocked1>, tensor<64x256x!tt.ptr<f16>, #blocked>)  : i32 {
      %72 = arith.muli %arg9, %c64_i32 : i32 
      %73 = arith.subi %arg5, %72 : i32 
      %74 = tt.splat %73 : i32 -> tensor<1x64xi32, #blocked1> 
      %75 = arith.cmpi slt, %33, %74 : tensor<1x64xi32, #blocked1>
      %76 = tt.broadcast %75 : tensor<1x64xi1, #blocked1> -> tensor<128x64xi1, #blocked1> 
      %77 = tt.load %arg11, %76, %cst_1 : tensor<128x64x!tt.ptr<f16>, #blocked1> 
      %78 = triton_gpu.local_alloc %77 : (tensor<128x64xf16, #blocked1>) -> !tt.memdesc<128x64xf16, #shared, #triton_gpu.shared_memory> 
      %79 = tt.splat %73 : i32 -> tensor<64x1xi32, #blocked> 
      %80 = arith.cmpi slt, %40, %79 : tensor<64x1xi32, #blocked> 
      %81 = tt.broadcast %80 : tensor<64x1xi1, #blocked> -> tensor<64x256xi1, #blocked> 
      %82 = tt.load %arg12, %81, %cst_0 : tensor<64x256x!tt.ptr<f16>, #blocked>
      %83 = triton_gpu.local_alloc %82 : (tensor<64x256xf16, #blocked>) -> !tt.memdesc<64x256xf16, #shared, #triton_gpu.shared_memory>
      %84 = triton_gpu.local_load %78 : !tt.memdesc<128x64xf16, #shared, #triton_gpu.shared_memory> -> tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>>
      %85 = triton_gpu.local_load %83 : !tt.memdesc<64x256xf16, #shared, #triton_gpu.shared_memory> -> tensor<64x256xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>>
      %86 = tt.dot %84, %85, %arg10, inputPrecision = tf32 : tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * tensor<64x256xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<128x256xf32, #mma> 
      %87 = tt.addptr %arg11, %cst_2 : tensor<128x64x!tt.ptr<f16>, #blocked1>, tensor<128x64xi32, #blocked1> 
      %88 = tt.addptr %arg12, %52 : tensor<64x256x!tt.ptr<f16>, #blocked>, tensor<64x256xi32, #blocked> 
      scf.yield %86, %87, %88 : tensor<128x256xf32, #mma>, tensor<128x64x!tt.ptr<f16>, #blocked1>, tensor<64x256x!tt.ptr<f16>, #blocked>
    }

By contrast, the code for an equivalent of the 03 tutorial written with block pointers (after forcing the block pointers to be rewritten by the pass) looks like:

%18 = arith.extsi %arg7 : i32 to i64
%19 = arith.extsi %17 : i32 to i64
%20 = arith.addi %arg5, %c63_i32 : i32
%21 = arith.divsi %20, %c64_i32 : i32 
%22:3 = scf.for %arg9 = %c0_i32 to %21 step %c1_i32 iter_args(%arg10 = %cst, %arg11 = %c0_i64, %arg12 = %c0_i64) -> (tensor<128x256xf32, #mma>, i64, i64)  : i32 {
      %43 = tt.splat %arg0 : !tt.ptr<f16> -> tensor<128x64x!tt.ptr<f16>, #blocked> 
      %44 = tt.splat %16 : i64 -> tensor<128xi64, #triton_gpu.slice<{dim = 1, parent = #blocked}>> 
      %45 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> 
      %46 = arith.extsi %45 : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> to tensor<128xi64, #triton_gpu.slice<{dim = 1, parent = #blocked}>> 
      %47 = arith.addi %44, %46 : tensor<128xi64, #triton_gpu.slice<{dim = 1, parent = #blocked}>> 
      %48 = tt.expand_dims %47 {axis = 1 : i32} : tensor<128xi64, #triton_gpu.slice<{dim = 1, parent = #blocked}>> -> tensor<128x1xi64, #blocked> 
      %49 = tt.splat %15 : i64 -> tensor<128x1xi64, #blocked> 
      %50 = arith.muli %48, %49 : tensor<128x1xi64, #blocked> 
      %51 = tt.broadcast %50 : tensor<128x1xi64, #blocked> -> tensor<128x64xi64, #blocked> 
      %52 = tt.addptr %43, %51 : tensor<128x64x!tt.ptr<f16>, #blocked>, tensor<128x64xi64, #blocked> 
      %53 = tt.splat %arg11 : i64 -> tensor<64xi64, #triton_gpu.slice<{dim = 0, parent = #blocked}>> 
      %54 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> 
      %55 = arith.extsi %54 : tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> to tensor<64xi64, #triton_gpu.slice<{dim = 0, parent = #blocked}>> 
      %56 = arith.addi %53, %55 : tensor<64xi64, #triton_gpu.slice<{dim = 0, parent = #blocked}>> 
      %57 = tt.expand_dims %56 {axis = 0 : i32} : tensor<64xi64, #triton_gpu.slice<{dim = 0, parent = #blocked}>> -> tensor<1x64xi64, #blocked> 
      %58 = tt.broadcast %57 : tensor<1x64xi64, #blocked> -> tensor<128x64xi64, #blocked> 
      %59 = tt.addptr %52, %58 : tensor<128x64x!tt.ptr<f16>, #blocked>, tensor<128x64xi64, #blocked> 
      %60 = tt.load %59 : tensor<128x64x!tt.ptr<f16>, #blocked> 
      %61 = triton_gpu.local_alloc %60 : (tensor<128x64xf16, #blocked>) -> !tt.memdesc<128x64xf16, #shared, #triton_gpu.shared_memory> 
      %62 = triton_gpu.local_load %61 : !tt.memdesc<128x64xf16, #shared, #triton_gpu.shared_memory> -> tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> 
      %63 = tt.splat %arg1 : !tt.ptr<f16> -> tensor<64x256x!tt.ptr<f16>, #blocked1> 
      %64 = tt.splat %arg12 : i64 -> tensor<64xi64, #triton_gpu.slice<{dim = 1, parent = #blocked1}>> 
      %65 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>> 
      %66 = arith.extsi %65 : tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>> to tensor<64xi64, #triton_gpu.slice<{dim = 1, parent = #blocked1}>> 
      %67 = arith.addi %64, %66 : tensor<64xi64, #triton_gpu.slice<{dim = 1, parent = #blocked1}>> 
      %68 = tt.expand_dims %67 {axis = 1 : i32} : tensor<64xi64, #triton_gpu.slice<{dim = 1, parent = #blocked1}>> -> tensor<64x1xi64, #blocked1> 
      %69 = tt.splat %18 : i64 -> tensor<64x1xi64, #blocked1> 
      %70 = arith.muli %68, %69 : tensor<64x1xi64, #blocked1> 
      %71 = tt.broadcast %70 : tensor<64x1xi64, #blocked1> -> tensor<64x256xi64, #blocked1> 
      %72 = tt.addptr %63, %71 : tensor<64x256x!tt.ptr<f16>, #blocked1>, tensor<64x256xi64, #blocked1> 
      %73 = tt.splat %19 : i64 -> tensor<256xi64, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> 
      %74 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> 
      %75 = arith.extsi %74 : tensor<256xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> to tensor<256xi64, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> 
      %76 = arith.addi %73, %75 : tensor<256xi64, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> 
      %77 = tt.expand_dims %76 {axis = 0 : i32} : tensor<256xi64, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> -> tensor<1x256xi64, #blocked1> 
      %78 = tt.broadcast %77 : tensor<1x256xi64, #blocked1> -> tensor<64x256xi64, #blocked1> 
      %79 = tt.addptr %72, %78 : tensor<64x256x!tt.ptr<f16>, #blocked1>, tensor<64x256xi64, #blocked1> 
      %80 = tt.load %79 : tensor<64x256x!tt.ptr<f16>, #blocked1> 
      %81 = triton_gpu.local_alloc %80 : (tensor<64x256xf16, #blocked1>) -> !tt.memdesc<64x256xf16, #shared, #triton_gpu.shared_memory> 
      %82 = triton_gpu.local_load %81 : !tt.memdesc<64x256xf16, #shared, #triton_gpu.shared_memory> -> tensor<64x256xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> 
      %83 = tt.dot %62, %82, %arg10, inputPrecision = tf32 : tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * tensor<64x256xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<128x256xf32, #mma>
      %84 = arith.addi %arg11, %c64_i64 : i64
      %85 = arith.addi %arg12, %c64_i64 : i64
      scf.yield %83, %84, %85 : tensor<128x256xf32, #mma>, i64, i64
    }

In the second listing, the tt.make_range / arith.extsi / tt.expand_dims / tt.broadcast chains inside the loop depend only on values defined outside of it; only the running offsets carried in %arg11 and %arg12 actually change across iterations. The RewriteTensorPointer pass should therefore be improved to hoist these loop-invariant instructions out of the loop.
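For comparison, a simplified sketch of the plain pointer-arithmetic form (roughly what the first listing above lowers from; names are illustrative and the epilogue is omitted) shows the structure the pass should ideally recover: the offset tensors and base pointers are built once before the loop, and each iteration only adds a constant stride.

import triton
import triton.language as tl

@triton.jit
def matmul_regular_ptr_kernel(a_ptr, b_ptr, M, N, K,
                              stride_am, stride_ak, stride_bk, stride_bn,
                              BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr,
                              BLOCK_K: tl.constexpr):
    pid_m = tl.program_id(0)
    pid_n = tl.program_id(1)
    offs_m = (pid_m * BLOCK_M + tl.arange(0, BLOCK_M)) % M
    offs_n = (pid_n * BLOCK_N + tl.arange(0, BLOCK_N)) % N
    offs_k = tl.arange(0, BLOCK_K)
    # Hoisted: the full address computation happens once, outside the K loop.
    a_ptrs = a_ptr + offs_m[:, None] * stride_am + offs_k[None, :] * stride_ak
    b_ptrs = b_ptr + offs_k[:, None] * stride_bk + offs_n[None, :] * stride_bn
    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
    for k in range(0, tl.cdiv(K, BLOCK_K)):
        a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_K, other=0.0)
        b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_K, other=0.0)
        acc = tl.dot(a, b, acc)
        # Only a cheap pointer increment remains inside the loop.
        a_ptrs += BLOCK_K * stride_ak
        b_ptrs += BLOCK_K * stride_bk
    # epilogue (store of acc) omitted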
