From c7fe682182a0a44d436db29a12f5540acb363a24 Mon Sep 17 00:00:00 2001 From: "Tiotto, Ettore" Date: Wed, 9 Oct 2024 16:43:52 +0000 Subject: [PATCH 01/32] Improve axis analysis to handle tt.make_tensor_ptr Signed-off-by: Tiotto, Ettore --- bin/RegisterTritonDialects.h | 5 + test/Analysis/intel/test-axis-info.mlir | 889 +++++++++++ test/lib/Analysis/CMakeLists.txt | 1 + test/lib/Analysis/intel/TestAxisInfo.cpp | 47 + third_party/intel/include/Analysis/AxisInfo.h | 215 +++ third_party/intel/lib/Analysis/AxisInfo.cpp | 1420 +++++++++++++++++ third_party/intel/lib/Analysis/CMakeLists.txt | 1 + 7 files changed, 2578 insertions(+) create mode 100644 test/Analysis/intel/test-axis-info.mlir create mode 100644 test/lib/Analysis/intel/TestAxisInfo.cpp create mode 100644 third_party/intel/include/Analysis/AxisInfo.h create mode 100644 third_party/intel/lib/Analysis/AxisInfo.cpp diff --git a/bin/RegisterTritonDialects.h b/bin/RegisterTritonDialects.h index 7f0855f8cd..2c907e9f0d 100644 --- a/bin/RegisterTritonDialects.h +++ b/bin/RegisterTritonDialects.h @@ -37,6 +37,10 @@ namespace mlir { namespace test { +namespace intel { +void registerTestAxisInfoPass(); +} + void registerTestAliasPass(); void registerTestAlignmentPass(); void registerTestAllocationPass(); @@ -50,6 +54,7 @@ inline void registerTritonDialects(mlir::DialectRegistry ®istry) { mlir::registerTritonPasses(); mlir::triton::gpu::registerTritonGPUPasses(); mlir::registerTritonNvidiaGPUPasses(); + mlir::test::intel::registerTestAxisInfoPass(); mlir::test::registerTestAliasPass(); mlir::test::registerTestAlignmentPass(); mlir::test::registerTestAllocationPass(); diff --git a/test/Analysis/intel/test-axis-info.mlir b/test/Analysis/intel/test-axis-info.mlir new file mode 100644 index 0000000000..95b846904c --- /dev/null +++ b/test/Analysis/intel/test-axis-info.mlir @@ -0,0 +1,889 @@ +// RUN: triton-opt %s -test-print-axis-info -split-input-file -o %t 2>&1 | FileCheck %s + +// CHECK-LABEL: @cast +tt.func @cast() { + // CHECK: contiguity = [1], divisibility = [1], constancy = [1], constant_value = 1 + %cst = arith.constant 1 : i32 + // CHECK-NEXT: contiguity = [1], divisibility = [1], constancy = [1], constant_value = 1 + %0 = arith.extsi %cst : i32 to i64 + // CHECK-NEXT: contiguity = [1], divisibility = [1], constancy = [128], constant_value = 1 + %cst_tensor = arith.constant dense<1> : tensor<128xi32> + // CHECK-NEXT: contiguity = [1], divisibility = [1], constancy = [128], constant_value = 1 + %1 = tt.bitcast %cst_tensor : tensor<128xi32> -> tensor<128xi64> + tt.return +} + +// ----- + +// CHECK-LABEL: @add +tt.func @add() { + // CHECK: contiguity = [128], divisibility = [1073741824], constancy = [1], constant_value = + %0 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32> + // CHECK-NEXT: contiguity = [1], divisibility = [1], constancy = [128], constant_value = 1 + %1 = arith.constant dense<1> : tensor<128xi32> + // CHECK-NEXT: contiguity = [128], divisibility = [1], constancy = [1], constant_value = + %2 = arith.addi %0, %1 : tensor<128xi32> + // CHECK-NEXT: contiguity = [1], divisibility = [1], constancy = [128], constant_value = 127 + %3 = arith.constant dense<127> : tensor<128xi32> + // CHECK-NEXT: contiguity = [1], divisibility = [128], constancy = [128], constant_value = 128 + %4 = arith.addi %1, %3 : tensor<128xi32> + tt.return +} + +// ----- + +// CHECK-LABEL: @addptr +tt.func @addptr(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: !tt.ptr {tt.divisibility = 16 : 
i32}, %arg3: !tt.ptr {tt.divisibility = 16 : i32}, %arg4: !tt.ptr {tt.divisibility = 16 : i32}) { + // CHECK: contiguity = [1], divisibility = [1], constancy = [1], constant_value = 1 + %cst1 = arith.constant 1 : i32 + // CHECK-NEXT: contiguity = [1], divisibility = [1], constancy = [1], constant_value = + %0 = tt.addptr %arg0, %cst1 : !tt.ptr, i32 + // CHECK-NEXT: contiguity = [1], divisibility = [1], constancy = [1], constant_value = + %1 = tt.addptr %arg1, %cst1 : !tt.ptr, i32 + // CHECK-NEXT: contiguity = [1], divisibility = [2], constancy = [1], constant_value = + %2 = tt.addptr %arg2, %cst1 : !tt.ptr, i32 + // CHECK-NEXT: contiguity = [1], divisibility = [4], constancy = [1], constant_value = + %3 = tt.addptr %arg3, %cst1 : !tt.ptr, i32 + // CHECK-NEXT: contiguity = [1], divisibility = [8], constancy = [1], constant_value = + %4 = tt.addptr %arg4, %cst1 : !tt.ptr, i32 + // CHECK-NEXT: contiguity = [1], divisibility = [4], constancy = [1], constant_value = 4 + %cst4 = arith.constant 4 : i32 + // CHECK-NEXT: contiguity = [1], divisibility = [4], constancy = [1], constant_value = + %5 = tt.addptr %arg0, %cst4 : !tt.ptr, i32 + // CHECK-NEXT: contiguity = [1], divisibility = [4], constancy = [1], constant_value = + %6 = tt.addptr %arg1, %cst4 : !tt.ptr, i32 + // CHECK-NEXT: contiguity = [1], divisibility = [8], constancy = [1], constant_value = + %7 = tt.addptr %arg2, %cst4 : !tt.ptr, i32 + // CHECK-NEXT: contiguity = [1], divisibility = [16], constancy = [1], constant_value = + %8 = tt.addptr %arg3, %cst4 : !tt.ptr, i32 + // CHECK-NEXT: contiguity = [1], divisibility = [16], constancy = [1], constant_value = + %9 = tt.addptr %arg4, %cst4 : !tt.ptr, i32 + // CHECK-NEXT: contiguity = [128], divisibility = [1073741824], constancy = [1], constant_value = + %10 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32> + // CHECK-NEXT: contiguity = [1, 128], divisibility = [1, 1073741824], constancy = [1, 1], constant_value = + %11 = tt.expand_dims %10 {axis = 0: i32} : tensor<128xi32> -> tensor<1x128xi32> + // CHECK-NEXT: contiguity = [1, 128], divisibility = [1, 1073741824], constancy = [128, 1], constant_value = + %12 = tt.broadcast %11 : tensor<1x128xi32> -> tensor<128x128xi32> + // CHECK-NEXT: contiguity = [1, 1], divisibility = [16, 16], constancy = [128, 128], constant_value = + %13 = tt.splat %arg0 : !tt.ptr -> tensor<128x128x!tt.ptr> + // CHECK-NEXT: contiguity = [1, 1], divisibility = [16, 16], constancy = [128, 128], constant_value = + %14 = tt.splat %arg1 : !tt.ptr -> tensor<128x128x!tt.ptr> + // CHECK-NEXT: contiguity = [1, 1], divisibility = [16, 16], constancy = [128, 128], constant_value = + %15 = tt.splat %arg2 : !tt.ptr -> tensor<128x128x!tt.ptr> + // CHECK-NEXT: contiguity = [1, 1], divisibility = [16, 16], constancy = [128, 128], constant_value = + %16 = tt.splat %arg3 : !tt.ptr -> tensor<128x128x!tt.ptr> + // CHECK-NEXT: contiguity = [1, 1], divisibility = [16, 16], constancy = [128, 128], constant_value = + %17 = tt.splat %arg4 : !tt.ptr -> tensor<128x128x!tt.ptr> + // CHECK-NEXT: contiguity = [1, 128], divisibility = [1, 16], constancy = [128, 1], constant_value = + %18 = tt.addptr %13, %12 : tensor<128x128x!tt.ptr>, tensor<128x128xi32> + // CHECK-NEXT: contiguity = [1, 128], divisibility = [1, 16], constancy = [128, 1], constant_value = + %19 = tt.addptr %14, %12 : tensor<128x128x!tt.ptr>, tensor<128x128xi32> + // CHECK-NEXT: contiguity = [1, 128], divisibility = [2, 16], constancy = [128, 1], constant_value = + %20 = tt.addptr %15, %12 : 
tensor<128x128x!tt.ptr>, tensor<128x128xi32> + // CHECK-NEXT: contiguity = [1, 128], divisibility = [4, 16], constancy = [128, 1], constant_value = + %21 = tt.addptr %16, %12 : tensor<128x128x!tt.ptr>, tensor<128x128xi32> + // CHECK-NEXT: contiguity = [1, 128], divisibility = [8, 16], constancy = [128, 1], constant_value = + %22 = tt.addptr %17, %12 : tensor<128x128x!tt.ptr>, tensor<128x128xi32> + tt.return +} + +// ----- + +// CHECK-LABEL: @sub +tt.func @sub() { + // CHECK: contiguity = [128], divisibility = [1073741824], constancy = [1], constant_value = + %0 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32> + // CHECK-NEXT: contiguity = [1], divisibility = [1], constancy = [128], constant_value = 1 + %1 = arith.constant dense<1> : tensor<128xi32> + // CHECK-NEXT: contiguity = [128], divisibility = [1], constancy = [1], constant_value = + %2 = arith.subi %0, %1 : tensor<128xi32> + // CHECK-NEXT: contiguity = [1], divisibility = [1], constancy = [128], constant_value = 129 + %3 = arith.constant dense<129> : tensor<128xi32> + // CHECK-NEXT: contiguity = [1], divisibility = [128], constancy = [128], constant_value = 128 + %4 = arith.subi %3, %1 : tensor<128xi32> + tt.return +} + +// ----- + +// CHECK-LABEL: @mul +tt.func @mul(%arg0: i64 {tt.divisibility = 16 : i32}) { + // CHECK: contiguity = [128], divisibility = [1073741824], constancy = [1], constant_value = + %0 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32> + // CHECK-NEXT: contiguity = [1], divisibility = [1], constancy = [128], constant_value = 1 + %1 = arith.constant dense<1> : tensor<128xi32> + // CHECK-NEXT: contiguity = [128], divisibility = [1073741824], constancy = [1], constant_value = + %2 = arith.muli %0, %1 : tensor<128xi32> + // CHECK-NEXT: contiguity = [1], divisibility = [128], constancy = [128], constant_value = 128 + %3 = arith.constant dense<128> : tensor<128xi32> + // CHECK-NEXT: contiguity = [1], divisibility = [128], constancy = [128], constant_value = 128 + %4 = arith.muli %3, %1 : tensor<128xi32> + // CHECK-NEXT: contiguity = [1], divisibility = [2], constancy = [128], constant_value = 2 + %5 = arith.constant dense<2> : tensor<128xi32> + // CHECK-NEXT: contiguity = [1], divisibility = [256], constancy = [128], constant_value = 256 + %6 = arith.muli %4, %5 : tensor<128xi32> + // CHECK-NEXT: contiguity = [1], divisibility = [4611686018427387904], constancy = [1], constant_value = 4611686018427387904 + %7 = arith.constant 4611686018427387904: i64 + // CHECK-NEXT: contiguity = [1], divisibility = [4611686018427387904], constancy = [1], constant_value = + %8 = arith.muli %arg0, %7 : i64 + tt.return +} + +// ----- + +// CHECK-LABEL: @div +tt.func @div() { + // CHECK: contiguity = [128], divisibility = [1073741824], constancy = [1], constant_value = + %0 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32> + // CHECK-NEXT: contiguity = [1], divisibility = [1], constancy = [128], constant_value = 1 + %1 = arith.constant dense<1> : tensor<128xi32> + // CHECK-NEXT: contiguity = [128], divisibility = [1073741824], constancy = [1], constant_value = + %2 = arith.divsi %0, %1 : tensor<128xi32> + // CHECK-NEXT: contiguity = [1], divisibility = [1], constancy = [1], constant_value = + %3 = arith.divui %1, %0 : tensor<128xi32> + // CHECK-NEXT: contiguity = [1], divisibility = [64], constancy = [128], constant_value = 64 + %4 = arith.constant dense<64> : tensor<128xi32> + // CHECK-NEXT: contiguity = [1], divisibility = [1], constancy = [64], constant_value = + %5 = 
arith.divsi %0, %4 : tensor<128xi32> + // CHECK-NEXT: contiguity = [1], divisibility = [1], constancy = [1], constant_value = + %6 = arith.divsi %4, %0 : tensor<128xi32> + // CHECK-NEXT: contiguity = [1], divisibility = [64], constancy = [128], constant_value = 64 + %7 = arith.divsi %4, %1 : tensor<128xi32> + // CHECK-NEXT: contiguity = [1], divisibility = [2], constancy = [128], constant_value = 66 + %8 = arith.constant dense<66> : tensor<128xi32> + // CHECK-NEXT: contiguity = [1], divisibility = [1], constancy = [2], constant_value = + %9 = arith.divui %0, %8 : tensor<128xi32> + // CHECK-NEXT: contiguity = [128], divisibility = [8192], constancy = [1], constant_value = + %10 = tt.make_range {end = 8320 : i32, start = 8192 : i32} : tensor<128xi32> + // CHECK-NEXT: contiguity = [1], divisibility = [1], constancy = [64], constant_value = + %11 = arith.divsi %10, %4 : tensor<128xi32> + tt.return +} + + +// ----- + +// CHECK-LABEL: @rem +tt.func @rem() { + // CHECK: contiguity = [128], divisibility = [1073741824], constancy = [1], constant_value = + %0 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32> + // CHECK-NEXT: contiguity = [1], divisibility = [1], constancy = [128], constant_value = 1 + %1 = arith.constant dense<1> : tensor<128xi32> + // CHECK-NEXT: contiguity = [1], divisibility = [4611686018427387904], constancy = [128], constant_value = 0 + %2 = arith.remsi %0, %1 : tensor<128xi32> + // CHECK-NEXT: contiguity = [1], divisibility = [1], constancy = [1], constant_value = + %3 = arith.remui %1, %0 : tensor<128xi32> + // CHECK-NEXT: contiguity = [1], divisibility = [64], constancy = [128], constant_value = 64 + %4 = arith.constant dense<64> : tensor<128xi32> + // CHECK-NEXT: contiguity = [64], divisibility = [64], constancy = [1], constant_value = + %5 = arith.remsi %0, %4 : tensor<128xi32> + // CHECK-NEXT: contiguity = [1], divisibility = [64], constancy = [1], constant_value = + %6 = arith.remsi %4, %0 : tensor<128xi32> + // CHECK-NEXT: contiguity = [1], divisibility = [2], constancy = [128], constant_value = 66 + %7 = arith.constant dense<66> : tensor<128xi32> + // CHECK-NEXT: contiguity = [2], divisibility = [2], constancy = [1], constant_value = + %8 = arith.remui %0, %7 : tensor<128xi32> + tt.return +} + +// ----- + +// CHECK-LABEL: @expanddims +tt.func @expanddims() { + // CHECK: contiguity = [128], divisibility = [1073741824], constancy = [1], constant_value = + %0 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32> + // CHECK-NEXT: contiguity = [1], divisibility = [2], constancy = [128], constant_value = 2 + %1 = arith.constant dense<2> : tensor<128xi32> + // CHECK-NEXT: contiguity = [1], divisibility = [2], constancy = [1], constant_value = + %2 = arith.muli %0, %1 : tensor<128xi32> + // CHECK-NEXT: contiguity = [1, 1], divisibility = [2, 2], constancy = [1, 1], constant_value = + %3 = tt.expand_dims %2 {axis = 1 : i32} : tensor<128xi32> -> tensor<128x1xi32> + tt.return +} + +// ----- + +// CHECK-LABEL: @broadcast +tt.func @broadcast() { + // CHECK: contiguity = [1], divisibility = [64], constancy = [128], constant_value = 64 + %0 = arith.constant dense<64> : tensor<128xi32> + // CHECK-NEXT: contiguity = [1, 1], divisibility = [64, 64], constancy = [128, 1], constant_value = 64 + %1 = tt.expand_dims %0 {axis = 1 : i32} : tensor<128xi32> -> tensor<128x1xi32> + // CHECK-NEXT: contiguity = [1, 1], divisibility = [64, 64], constancy = [128, 128], constant_value = 64 + %2 = tt.broadcast %1 : tensor<128x1xi32> -> tensor<128x128xi32> + tt.return 
+} + +// ----- + +// CHECK-LABEL: @splat +tt.func @splat(%arg0: !tt.ptr {tt.divisibility = 16 : i32}) { + // CHECK: contiguity = [1, 1], divisibility = [16, 16], constancy = [128, 128], constant_value = + %0 = tt.splat %arg0 : !tt.ptr -> tensor<128x128x!tt.ptr> + tt.return +} + +// ----- + +// CHECK-LABEL: @cmp_all_contiguous +tt.func @cmp_all_contiguous() { + // CHECK: contiguity = [128], divisibility = [1073741824], constancy = [1], constant_value = + %0 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32> + // CHECK-NEXT: contiguity = [1], divisibility = [4611686018427387904], constancy = [128], constant_value = 0 + %1 = arith.constant dense<0> : tensor<128xi32> + // CHECK-NEXT: contiguity = [1], divisibility = [1], constancy = [1], constant_value = + %2 = arith.cmpi eq, %0, %1 : tensor<128xi32> + // CHECK-NEXT: contiguity = [1], divisibility = [1], constancy = [1], constant_value = + %3 = arith.cmpi ne, %0, %1 : tensor<128xi32> + // CHECK-NEXT: contiguity = [1], divisibility = [1], constancy = [128], constant_value = + %4 = arith.cmpi slt, %0, %1 : tensor<128xi32> + // CHECK-NEXT: contiguity = [1], divisibility = [1], constancy = [1], constant_value = + %5 = arith.cmpi sle, %0, %1 : tensor<128xi32> + // CHECK-NEXT: contiguity = [1], divisibility = [1], constancy = [128], constant_value = + %6 = arith.cmpi sge, %0, %1 : tensor<128xi32> + // CHECK-NEXT: contiguity = [1], divisibility = [1], constancy = [1], constant_value = + %7 = arith.cmpi sgt, %0, %1 : tensor<128xi32> + // CHECK-NEXT: contiguity = [1], divisibility = [1], constancy = [1], constant_value = + %8 = arith.cmpi eq, %1, %0 : tensor<128xi32> + // CHECK-NEXT: contiguity = [1], divisibility = [1], constancy = [1], constant_value = + %9 = arith.cmpi ne, %1, %0 : tensor<128xi32> + // CHECK-NEXT: contiguity = [1], divisibility = [1], constancy = [1], constant_value = + %10 = arith.cmpi slt, %1, %0 : tensor<128xi32> + // CHECK-NEXT: contiguity = [1], divisibility = [1], constancy = [128], constant_value = + %11 = arith.cmpi sle, %1, %0 : tensor<128xi32> + // CHECK-NEXT: contiguity = [1], divisibility = [1], constancy = [1], constant_value = + %12 = arith.cmpi sge, %1, %0 : tensor<128xi32> + // CHECK-NEXT: contiguity = [1], divisibility = [1], constancy = [128], constant_value = + %13 = arith.cmpi sgt, %1, %0 : tensor<128xi32> + // CHECK-NEXT: contiguity = [1], divisibility = [8], constancy = [128], constant_value = 8 + %14 = arith.constant dense<8> : tensor<128xi32> + // CHECK-NEXT: contiguity = [1], divisibility = [1], constancy = [8], constant_value = + %15 = arith.cmpi sgt, %14, %0 : tensor<128xi32> + // CHECK-NEXT: contiguity = [1], divisibility = [1], constancy = [128], constant_value = 1 + %16 = arith.cmpi sgt, %14, %1 : tensor<128xi32> + tt.return +} + +// CHECK-LABEL: @cmp_partial_contiguous +tt.func @cmp_partial_contiguous() { + // CHECK: contiguity = [128], divisibility = [1073741824], constancy = [1], constant_value = + %0 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32> + // CHECK-NEXT: contiguity = [1], divisibility = [8], constancy = [128], constant_value = 8 + %1 = arith.constant dense<8> : tensor<128xi32> + // CHECK-NEXT: contiguity = [1], divisibility = [32], constancy = [128], constant_value = 32 + %3 = arith.constant dense<32> : tensor<128xi32> + // CHECK-NEXT: contiguity = [32], divisibility = [32], constancy = [1], constant_value = + %4 = arith.remsi %0, %3 : tensor<128xi32> + // CHECK-NEXT: contiguity = [1], divisibility = [1], constancy = [1], constant_value = + %5 = 
arith.cmpi eq, %4, %1 : tensor<128xi32> + // CHECK-NEXT: contiguity = [1], divisibility = [1], constancy = [1], constant_value = + %6 = arith.cmpi ne, %4, %1 : tensor<128xi32> + // CHECK-NEXT: contiguity = [1], divisibility = [1], constancy = [8], constant_value = + %7 = arith.cmpi slt, %4, %1 : tensor<128xi32> + // CHECK-NEXT: contiguity = [1], divisibility = [1], constancy = [1], constant_value = + %8 = arith.cmpi sle, %4, %1 : tensor<128xi32> + // CHECK-NEXT: contiguity = [1], divisibility = [1], constancy = [8], constant_value = + %9 = arith.cmpi sge, %4, %1 : tensor<128xi32> + // CHECK-NEXT: contiguity = [1], divisibility = [1], constancy = [1], constant_value = + %10 = arith.cmpi sgt, %4, %1 : tensor<128xi32> + // CHECK-NEXT: contiguity = [1], divisibility = [1], constancy = [1], constant_value = + %11 = arith.cmpi eq, %1, %4 : tensor<128xi32> + // CHECK-NEXT: contiguity = [1], divisibility = [1], constancy = [1], constant_value = + %12 = arith.cmpi ne, %1, %4 : tensor<128xi32> + // CHECK-NEXT: contiguity = [1], divisibility = [1], constancy = [1], constant_value = + %13 = arith.cmpi slt, %1, %4 : tensor<128xi32> + // CHECK-NEXT: contiguity = [1], divisibility = [1], constancy = [8], constant_value = + %14 = arith.cmpi sle, %1, %4 : tensor<128xi32> + // CHECK-NEXT: contiguity = [1], divisibility = [1], constancy = [1], constant_value = + %15 = arith.cmpi sge, %1, %4 : tensor<128xi32> + // CHECK-NEXT: contiguity = [1], divisibility = [1], constancy = [8], constant_value = + %16 = arith.cmpi sgt, %1, %4 : tensor<128xi32> + // CHECK-NEXT: contiguity = [1], divisibility = [16], constancy = [128], constant_value = 48 + %17 = arith.constant dense<48> : tensor<128xi32> + // CHECK-NEXT: contiguity = [16], divisibility = [16], constancy = [1], constant_value = + %18 = arith.remsi %0, %17 : tensor<128xi32> + // CHECK-NEXT: contiguity = [1], divisibility = [1], constancy = [1], constant_value = + %19 = arith.cmpi eq, %18, %3 : tensor<128xi32> + // CHECK-NEXT: contiguity = [1], divisibility = [1], constancy = [1], constant_value = + %20 = arith.cmpi ne, %18, %3 : tensor<128xi32> + // CHECK-NEXT: contiguity = [1], divisibility = [1], constancy = [16], constant_value = + %21 = arith.cmpi slt, %18, %3 : tensor<128xi32> + // CHECK-NEXT: contiguity = [1], divisibility = [1], constancy = [1], constant_value = + %22 = arith.cmpi sle, %18, %3 : tensor<128xi32> + // CHECK-NEXT: contiguity = [1], divisibility = [1], constancy = [16], constant_value = + %23 = arith.cmpi sge, %18, %3 : tensor<128xi32> + // CHECK-NEXT: contiguity = [1], divisibility = [1], constancy = [1], constant_value = + %24 = arith.cmpi sgt, %18, %3 : tensor<128xi32> + // CHECK-NEXT: contiguity = [1], divisibility = [1], constancy = [1], constant_value = + %25 = arith.cmpi eq, %3, %18 : tensor<128xi32> + // CHECK-NEXT: contiguity = [1], divisibility = [1], constancy = [1], constant_value = + %26 = arith.cmpi ne, %3, %18 : tensor<128xi32> + // CHECK-NEXT: contiguity = [1], divisibility = [1], constancy = [1], constant_value = + %27 = arith.cmpi slt, %3, %18 : tensor<128xi32> + // CHECK-NEXT: contiguity = [1], divisibility = [1], constancy = [16], constant_value = + %28 = arith.cmpi sle, %3, %18 : tensor<128xi32> + // CHECK-NEXT: contiguity = [1], divisibility = [1], constancy = [1], constant_value = + %29 = arith.cmpi sge, %3, %18 : tensor<128xi32> + // CHECK-NEXT: contiguity = [1], divisibility = [1], constancy = [16], constant_value = + tt.return +} + +// ----- + +// CHECK-LABEL: @logic +tt.func @logic() { + // CHECK: contiguity = 
[128], divisibility = [1073741824], constancy = [1], constant_value = + %0 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32> + // CHECK-NEXT: contiguity = [1], divisibility = [64], constancy = [128], constant_value = 64 + %1 = arith.constant dense<64> : tensor<128xi32> + // CHECK-NEXT: contiguity = [1], divisibility = [1], constancy = [64], constant_value = + %2 = arith.divsi %0, %1 : tensor<128xi32> + // CHECK-NEXT: contiguity = [1], divisibility = [8], constancy = [128], constant_value = 8 + %3 = arith.constant dense<8> : tensor<128xi32> + // CHECK-NEXT: contiguity = [1], divisibility = [1], constancy = [8], constant_value = + %4 = arith.divsi %0, %3 : tensor<128xi32> + // CHECK-NEXT: contiguity = [1], divisibility = [1], constancy = [1], constant_value = + %5 = arith.andi %0, %1 : tensor<128xi32> + // CHECK-NEXT: contiguity = [1], divisibility = [1], constancy = [1], constant_value = + %6 = arith.ori %0, %1 : tensor<128xi32> + // CHECK-NEXT: contiguity = [1], divisibility = [1], constancy = [1], constant_value = + %7 = arith.xori %0, %1 : tensor<128xi32> + // CHECK-NEXT: contiguity = [1], divisibility = [1], constancy = [8], constant_value = + %8 = arith.andi %2, %4 : tensor<128xi32> + // CHECK-NEXT: contiguity = [1], divisibility = [1], constancy = [8], constant_value = + %9 = arith.ori %2, %4 : tensor<128xi32> + // CHECK-NEXT: contiguity = [1], divisibility = [1], constancy = [8], constant_value = + %10 = arith.xori %2, %4 : tensor<128xi32> + tt.return +} + +// ----- + +// CHECK-LABEL: @select +tt.func @select(%arg0 : i1, %arg1 : tensor<4xi1>) { + // CHECK: contiguity = [128], divisibility = [1073741824], constancy = [1], constant_value = + %0 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32> + // CHECK-NEXT: contiguity = [1], divisibility = [4611686018427387904], constancy = [128], constant_value = 0 + %1 = arith.constant dense<0> : tensor<128xi32> + // CHECK-NEXT: contiguity = [1], divisibility = [1], constancy = [1], constant_value = + %2 = arith.cmpi eq, %0, %1 : tensor<128xi32> + // CHECK-NEXT: contiguity = [1], divisibility = [1], constancy = [128], constant_value = + %3 = arith.cmpi slt, %0, %1 : tensor<128xi32> + // CHECK-NEXT: contiguity = [1], divisibility = [4611686018427387904], constancy = [1], constant_value = 0 + %4 = arith.constant 0 : i1 + // CHECK-NEXT: contiguity = [1], divisibility = [4611686018427387904], constancy = [128], constant_value = 0 + %7 = tt.splat %4 : i1 -> tensor<128xi1> + // CHECK-NEXT: contiguity = [1], divisibility = [4611686018427387904], constancy = [128], constant_value = 0 + %5 = arith.select %4, %3, %7 : tensor<128xi1> + // CHECK-NEXT: contiguity = [1], divisibility = [1], constancy = [1], constant_value = + %8 = arith.select %7, %3, %2 : tensor<128xi1>, tensor<128xi1> + // CHECK-NEXT: contiguity = [1, 1], divisibility = [1, 1], constancy = [1, 1], constant_value = + %9 = tt.expand_dims %2 {axis = 1 : i32} : tensor<128xi1> -> tensor<128x1xi1> + // CHECK-NEXT: contiguity = [1, 1], divisibility = [1, 1], constancy = [128, 1], constant_value = + %10 = tt.expand_dims %3 {axis = 1 : i32} : tensor<128xi1> -> tensor<128x1xi1> + // CHECK-NEXT: contiguity = [1, 1], divisibility = [1, 1], constancy = [1, 1], constant_value = + %11 = arith.select %arg0, %9, %10 : tensor<128x1xi1> + // CHECK-NEXT: contiguity = [1], divisibility = [4], constancy = [4], constant_value = 4 + %cst = arith.constant dense<4> : tensor<4xi32> + // CHECK-NEXT: contiguity = [4], divisibility = [1073741824], constancy = [1], constant_value = + 
%12 = tt.make_range {end = 4 : i32, start = 0 : i32} : tensor<4xi32> + // CHECK-NEXT: contiguity = [1], divisibility = [4], constancy = [1], constant_value = + %13 = arith.muli %12, %cst : tensor<4xi32> + // CHECK-NEXT: contiguity = [4], divisibility = [16], constancy = [1], constant_value = + %14 = tt.make_range {end = 20 : i32, start = 16 : i32} : tensor<4xi32> + // CHECK-NEXT: contiguity = [1], divisibility = [1], constancy = [1], constant_value = + %15 = arith.select %arg1, %12, %13 : tensor<4xi1>, tensor<4xi32> + tt.return +} + +// ----- + +tt.func @shift(%arg0: i32 {tt.divisibility = 4 : i32}) { + // CHECK: contiguity = [1], divisibility = [4], constancy = [128], constant_value = + %s = tt.splat %arg0 : i32 -> tensor<128xi32> + // CHECK-NEXT: contiguity = [128], divisibility = [1073741824], constancy = [1], constant_value = + %0 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32> + // CHECK-NEXT: contiguity = [1], divisibility = [8], constancy = [128], constant_value = 8 + %1 = arith.constant dense<8> : tensor<128xi32> + // CHECK-NEXT: contiguity = [1], divisibility = [4], constancy = [128], constant_value = 4 + %2 = arith.constant dense<4> : tensor<128xi32> + // CHECK-NEXT: contiguity = [1], divisibility = [256], constancy = [1], constant_value = + %3 = arith.shli %0, %1 : tensor<128xi32> + // CHECK-NEXT: contiguity = [1], divisibility = [1], constancy = [1], constant_value = + %4 = arith.shrsi %0, %2 : tensor<128xi32> + // CHECK-NEXT: contiguity = [1], divisibility = [128], constancy = [128], constant_value = 128 + %5 = arith.shli %1, %2 : tensor<128xi32> + // CHECK-NEXT: contiguity = [1], divisibility = [8], constancy = [128], constant_value = + %6 = arith.shli %1, %s : tensor<128xi32> + // CHECK-NEXT: contiguity = [1], divisibility = [1], constancy = [1], constant_value = + %7 = arith.shrsi %0, %s : tensor<128xi32> + tt.return +} + +// ----- + +tt.func @max_min() { + // CHECK: contiguity = [128], divisibility = [1073741824], constancy = [1], constant_value = + %0 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32> + // CHECK-NEXT: contiguity = [128], divisibility = [64], constancy = [1], constant_value = + %1 = tt.make_range {end = 192 : i32, start = 64 : i32} : tensor<128xi32> + // CHECK-NEXT: contiguity = [128], divisibility = [64], constancy = [1], constant_value = + %2 = arith.maxsi %0, %1 : tensor<128xi32> + // CHECK-NEXT: contiguity = [128], divisibility = [64], constancy = [1], constant_value = + %3 = arith.minsi %0, %1 : tensor<128xi32> + // CHECK-NEXT: contiguity = [1], divisibility = [8], constancy = [128], constant_value = 8 + %4 = arith.constant dense<8> : tensor<128xi32> + // CHECK-NEXT: contiguity = [1], divisibility = [4], constancy = [128], constant_value = 4 + %5 = arith.constant dense<4> : tensor<128xi32> + // CHECK-NEXT: contiguity = [1], divisibility = [1], constancy = [1], constant_value = 8 + %6 = arith.maxsi %4, %5 : tensor<128xi32> + tt.return +} + +// ----- + +// CHECK-LABEL: @if +tt.func @if(%i1 : i1) { + // CHECK: contiguity = [1, 1], divisibility = [64, 64], constancy = [128, 32], constant_value = 64 + %cst_64 = arith.constant dense<64> : tensor<128x32xi32> + // CHECK-NEXT: contiguity = [1, 1], divisibility = [1, 1], constancy = [128, 32], constant_value = 1 + %cst_1 = arith.constant dense<1> : tensor<128x32xi32> + // CHECK-NEXT: contiguity = [1, 1], divisibility = [64, 64], constancy = [128, 32], constant_value = 64 + %a = arith.muli %cst_64, %cst_1 : tensor<128x32xi32> + // CHECK: contiguity = [1, 1], divisibility = 
[1, 1], constancy = [128, 32], constant_value = + %ret = scf.if %i1 -> tensor<128x32xi32> { + scf.yield %a : tensor<128x32xi32> + } else { + scf.yield %cst_1 : tensor<128x32xi32> + } + tt.return +} + +// ----- + +// CHECK-LABEL: @for +tt.func @for() { + // CHECK: contiguity = [1, 1], divisibility = [4611686018427387904, 4611686018427387904], constancy = [128, 32], constant_value = 0 + %a_init = arith.constant dense<0> : tensor<128x32xi32> + // CHECK-NEXT: contiguity = [1, 1], divisibility = [1, 1], constancy = [128, 32], constant_value = 1 + %b_init = arith.constant dense<1> : tensor<128x32xi32> + // CHECK-NEXT: contiguity = [1, 1], divisibility = [4, 4], constancy = [128, 32], constant_value = 4 + %c_init = arith.constant dense<4> : tensor<128x32xi32> + // CHECK-NEXT: contiguity = [1], divisibility = [128], constancy = [1], constant_value = 128 + %ub = arith.constant 128 : index + // CHECK-NEXT: contiguity = [1], divisibility = [4611686018427387904], constancy = [1], constant_value = 0 + %lb = arith.constant 0 : index + // CHECK-NEXT: contiguity = [1], divisibility = [16], constancy = [1], constant_value = 16 + %step = arith.constant 16 : index + %a, %b, %c = scf.for %iv = %lb to %ub step %step iter_args(%a = %a_init, %b = %b_init, %c = %c_init) -> (tensor<128x32xi32>, tensor<128x32xi32>, tensor<128x32xi32>) { + // CHECK-NEXT: contiguity = [1], divisibility = [16], constancy = [1], constant_value = + %t = arith.index_cast %iv : index to i32 + // CHECK: contiguity = [1, 1], divisibility = [1, 1], constancy = [128, 32], constant_value = + // CHECK: contiguity = [1, 1], divisibility = [1, 1], constancy = [128, 32], constant_value = + // CHECK: contiguity = [1, 1], divisibility = [4, 4], constancy = [128, 32], constant_value = 4 + scf.yield %b, %a, %c : tensor<128x32xi32>, tensor<128x32xi32>, tensor<128x32xi32> + } + tt.return +} + +// ----- + +// CHECK-LABEL: @for_dynamic +tt.func @for_dynamic(%lb: index {tt.divisibility = 16 : i32}, %step: index {tt.divisibility = 8 : i32}, %ub: index) { + scf.for %iv = %lb to %ub step %step { + // CHECK-NEXT: contiguity = [1], divisibility = [8], constancy = [1], constant_value = + %t = arith.index_cast %iv : index to i32 + } + tt.return +} + +// ----- + +// CHECK-LABEL: @for_if +tt.func @for_if(%i1: i1, %arg0: !tt.ptr {tt.divisibility = 16 : i32}) { + // CHECK: contiguity = [1], divisibility = [4611686018427387904], constancy = [1], constant_value = 0 + %c0_i32 = arith.constant 0 : i32 + // CHECK-NEXT: contiguity = [1], divisibility = [1], constancy = [1], constant_value = 1 + %c1_i32 = arith.constant 1 : i32 + // CHECK-NEXT: contiguity = [1], divisibility = [2], constancy = [1], constant_value = 10 + %c10_i32 = arith.constant 10 : i32 + // CHECK-NEXT: contiguity = [1, 1], divisibility = [64, 64], constancy = [128, 64], constant_value = 64 + %cst = arith.constant dense<64> : tensor<128x64xi32> + // CHECK-NEXT: contiguity = [1, 1], divisibility = [16, 16], constancy = [128, 64], constant_value = + %1 = tt.splat %arg0 : !tt.ptr -> tensor<128x64x!tt.ptr> + %2 = scf.for %arg9 = %c0_i32 to %c10_i32 step %c1_i32 iter_args(%arg1 = %1) -> (tensor<128x64x!tt.ptr>): i32 { + // CHECK: scf.if + // CHECK: contiguity = [1, 1], divisibility = [16, 16], constancy = [128, 64], constant_value = + %3 = scf.if %i1 -> (tensor<128x64x!tt.ptr>) { + scf.yield %arg1 : tensor<128x64x!tt.ptr> + } else { + scf.yield %arg1 : tensor<128x64x!tt.ptr> + } + // CHECK: tt.addptr + // CHECK-SAME: contiguity = [1, 1], divisibility = [16, 16], constancy = [128, 64], constant_value = + %4 = 
tt.addptr %3, %cst : tensor<128x64x!tt.ptr>, tensor<128x64xi32> + // CHECK: scf.for + // CHECK: contiguity = [1, 1], divisibility = [16, 16], constancy = [128, 64], constant_value = + scf.yield %1 : tensor<128x64x!tt.ptr> + } + tt.return +} + +// ----- + +// CHECK-LABEL: @for_if_for +tt.func @for_if_for(%i1: i1, %arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 8 : i32}) { + // CHECK: contiguity = [1], divisibility = [4611686018427387904], constancy = [1], constant_value = 0 + %c0_i32 = arith.constant 0 : i32 + // CHECK-NEXT: contiguity = [1], divisibility = [1], constancy = [1], constant_value = 1 + %c1_i32 = arith.constant 1 : i32 + // CHECK-NEXT: contiguity = [1], divisibility = [2], constancy = [1], constant_value = 10 + %c10_i32 = arith.constant 10 : i32 + // CHECK-NEXT: contiguity = [1, 1], divisibility = [64, 64], constancy = [128, 64], constant_value = 64 + %cst = arith.constant dense<64> : tensor<128x64xi32> + // CHECK-NEXT: contiguity = [1, 1], divisibility = [16, 16], constancy = [128, 64], constant_value = + %1 = tt.splat %arg0 : !tt.ptr -> tensor<128x64x!tt.ptr> + // CHECK-NEXT: contiguity = [1, 1], divisibility = [8, 8], constancy = [128, 64], constant_value = + %2 = tt.splat %arg1 : !tt.ptr -> tensor<128x64x!tt.ptr> + // CHECK: scf.for + // CHECK: contiguity = [1, 1], divisibility = [8, 8], constancy = [128, 64], constant_value = + // CHECK: scf.if + // CHECK: contiguity = [1, 1], divisibility = [8, 8], constancy = [128, 64], constant_value = + // CHECK: tt.addptr + // CHECK-SAME: contiguity = [1, 1], divisibility = [8, 8], constancy = [128, 64], constant_value = + // CHECK: scf.for + // CHECK: contiguity = [1, 1], divisibility = [16, 16], constancy = [128, 64], constant_value = + %3 = scf.for %arg9 = %c0_i32 to %c10_i32 step %c1_i32 iter_args(%arg2 = %1) -> (tensor<128x64x!tt.ptr>) : i32 { + %4 = scf.if %i1 -> (tensor<128x64x!tt.ptr>) { + %5 = scf.for %arg10 = %c0_i32 to %c10_i32 step %c1_i32 iter_args(%arg3 = %2) -> (tensor<128x64x!tt.ptr>) : i32 { + scf.yield %arg3 : tensor<128x64x!tt.ptr> + } + scf.yield %5 : tensor<128x64x!tt.ptr> + } else { + scf.yield %arg2 : tensor<128x64x!tt.ptr> + } + %6 = tt.addptr %4, %cst : tensor<128x64x!tt.ptr>, tensor<128x64xi32> + scf.yield %1 : tensor<128x64x!tt.ptr> + } + tt.return +} + +// ----- + +// CHECK-LABEL: @permute_2d +tt.func @permute_2d(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: i32 {tt.divisibility = 16 : i32}, %arg2: !tt.ptr {tt.divisibility = 16 : i32}, %arg3: i32 {tt.divisibility = 16 : i32}) { + // CHECK: contiguity = [1, 1], divisibility = [1, 1], constancy = [128, 128], constant_value = 1 + %cst = arith.constant dense : tensor<128x128xi1> + // CHECK-NEXT: contiguity = [1, 1], divisibility = [1, 1], constancy = [1, 1], constant_value = + %cst_0 = arith.constant dense<0.000000e+00> : tensor<128x128xf32> + // CHECK-NEXT: contiguity = [128], divisibility = [1073741824], constancy = [1], constant_value = + %0 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32> + // CHECK-NEXT: contiguity = [128], divisibility = [1073741824], constancy = [1], constant_value = + %1 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32> + // CHECK-NEXT: contiguity = [128, 1], divisibility = [1073741824, 1], constancy = [1, 1], constant_value = + %2 = tt.expand_dims %0 {axis = 1 : i32} : tensor<128xi32> -> tensor<128x1xi32> + // CHECK-NEXT: contiguity = [1, 1], divisibility = [16, 16], constancy = [128, 1], constant_value = + %3 = tt.splat %arg1 : i32 -> tensor<128x1xi32> + // 
CHECK-NEXT: contiguity = [1, 1], divisibility = [16, 16], constancy = [1, 1], constant_value = + %4 = arith.muli %2, %3 : tensor<128x1xi32> + // CHECK-NEXT: contiguity = [1, 1], divisibility = [16, 16], constancy = [128, 1], constant_value = + %5 = tt.splat %arg0 : !tt.ptr -> tensor<128x1x!tt.ptr> + // CHECK-NEXT: contiguity = [1, 1], divisibility = [16, 16], constancy = [1, 1], constant_value = + %6 = tt.addptr %5, %4 : tensor<128x1x!tt.ptr>, tensor<128x1xi32> + // CHECK-NEXT: contiguity = [1, 128], divisibility = [1, 1073741824], constancy = [1, 1], constant_value = + %7 = tt.expand_dims %1 {axis = 0 : i32}: tensor<128xi32> -> tensor<1x128xi32> + // CHECK-NEXT: contiguity = [1, 1], divisibility = [16, 16], constancy = [1, 128], constant_value = + %8 = tt.broadcast %6 : tensor<128x1x!tt.ptr> -> tensor<128x128x!tt.ptr> + // CHECK-NEXT: contiguity = [1, 128], divisibility = [1, 1073741824], constancy = [128, 1], constant_value = + %9 = tt.broadcast %7 : tensor<1x128xi32> -> tensor<128x128xi32> + // CHECK-NEXT: contiguity = [1, 128], divisibility = [4, 16], constancy = [1, 1], constant_value = + %10 = tt.addptr %8, %9 : tensor<128x128x!tt.ptr>, tensor<128x128xi32> + // CHECK-NEXT: contiguity = [128, 1], divisibility = [1073741824, 1], constancy = [1, 1], constant_value = + %11 = tt.expand_dims %0 {axis = 1 : i32}: tensor<128xi32> -> tensor<128x1xi32> + // CHECK-NEXT: contiguity = [1, 1], divisibility = [16, 16], constancy = [128, 1], constant_value = + %12 = tt.splat %arg2 : !tt.ptr -> tensor<128x1x!tt.ptr> + // CHECK-NEXT: contiguity = [128, 1], divisibility = [16, 4], constancy = [1, 1], constant_value = + %13 = tt.addptr %12, %11 : tensor<128x1x!tt.ptr>, tensor<128x1xi32> + // CHECK-NEXT: contiguity = [1, 128], divisibility = [1, 1073741824], constancy = [1, 1], constant_value = + %14 = tt.expand_dims %1 {axis = 0 : i32} : tensor<128xi32> -> tensor<1x128xi32> + // CHECK-NEXT: contiguity = [1, 1], divisibility = [16, 16], constancy = [1, 128], constant_value = + %15 = tt.splat %arg3 : i32 -> tensor<1x128xi32> + // CHECK-NEXT: contiguity = [1, 1], divisibility = [16, 16], constancy = [1, 1], constant_value = + %16 = arith.muli %14, %15 : tensor<1x128xi32> + // CHECK-NEXT: contiguity = [128, 1], divisibility = [16, 4], constancy = [1, 128], constant_value = + %17 = tt.broadcast %13 : tensor<128x1x!tt.ptr> -> tensor<128x128x!tt.ptr> + // CHECK-NEXT: contiguity = [1, 1], divisibility = [16, 16], constancy = [128, 1], constant_value = + %18 = tt.broadcast %16 : tensor<1x128xi32> -> tensor<128x128xi32> + // CHECK-NEXT: contiguity = [128, 1], divisibility = [16, 4], constancy = [1, 1], constant_value = + %19 = tt.addptr %17, %18 : tensor<128x128x!tt.ptr>, tensor<128x128xi32> + // CHECK-NEXT: contiguity = [1, 1], divisibility = [1, 1], constancy = [1, 1], constant_value = + %20 = tt.load %10, %cst, %cst_0 : tensor<128x128x!tt.ptr> + tt.store %19, %20, %cst : tensor<128x128x!tt.ptr> + tt.return +} + +// ----- + +// CHECK-LABEL: @load_constancy +tt.func @load_constancy(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: i32 {tt.divisibility = 1 : i32}) { + // CHECK: divisibility = [16] + %sixteen = arith.constant dense<16> : tensor<1024xi32> + // CHECK-NEXT: divisibility = [8] + %eight = arith.constant dense<8> : tensor<1024xi32> + // CHECK-NEXT: contiguity = [1024], divisibility = [1073741824], constancy = [1] + %1 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32> + // CHECK-NEXT: constancy = [16] + %2 = arith.divsi %1, %sixteen : tensor<1024xi32> + // CHECK-NEXT: constancy = 
[1024]
+  %3 = tt.splat %arg0 : !tt.ptr -> tensor<1024x!tt.ptr>
+  // CHECK-NEXT: constancy = [1024]
+  %4 = tt.splat %arg1 : i32 -> tensor<1024xi32>
+  // CHECK-NEXT: constancy = [8]
+  %5 = arith.divsi %1, %eight : tensor<1024xi32>
+  // CHECK-NEXT: constancy = [8]
+  %6 = arith.cmpi slt, %5, %4 : tensor<1024xi32>
+  // CHECK-NEXT: constancy = [16]
+  %7 = tt.addptr %3, %2 : tensor<1024x!tt.ptr>, tensor<1024xi32>
+  // CHECK-NEXT: constancy = [16]
+  %8 = tt.load %7 : tensor<1024x!tt.ptr>
+  // CHECK-NEXT: constancy = [8]
+  %9 = tt.load %7, %6 : tensor<1024x!tt.ptr>
+  tt.return
+}
+
+// -----
+
+// This is a tiny test for verifying StoreOp-related alignment. It simply stores a constant to a buffer.
+// CHECK-LABEL: @store_constant_align
+tt.func @store_constant_align(%addr: !tt.ptr {tt.divisibility = 16 : i32}, %n: i32 {tt.divisibility = 16 : i32}) {
+  // CHECK: contiguity = [1], divisibility = [1], constancy = [1], constant_value =
+  %pid = tt.get_program_id x : i32
+  // CHECK-NEXT: contiguity = [1], divisibility = [128], constancy = [1], constant_value = 128
+  %c128_i32 = arith.constant 128 : i32
+  // CHECK-NEXT: contiguity = [1], divisibility = [128], constancy = [1], constant_value =
+  %1 = arith.muli %pid, %c128_i32 : i32
+  // CHECK-NEXT: contiguity = [128], divisibility = [1073741824], constancy = [1], constant_value =
+  %2 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32>
+  // CHECK-NEXT: contiguity = [1], divisibility = [128], constancy = [128], constant_value =
+  %3 = tt.splat %1 : i32 -> tensor<128xi32>
+  // CHECK-NEXT: contiguity = [128], divisibility = [128], constancy = [1], constant_value =
+  %4 = arith.addi %3, %2 : tensor<128xi32>
+  // CHECK-NEXT: contiguity = [1], divisibility = [16], constancy = [128], constant_value =
+  %5 = tt.splat %addr : !tt.ptr -> tensor<128x!tt.ptr>
+  // CHECK-NEXT: contiguity = [128], divisibility = [16], constancy = [1], constant_value =
+  %6 = tt.addptr %5, %4 : tensor<128x!tt.ptr>, tensor<128xi32>
+  // CHECK-NEXT: contiguity = [1], divisibility = [16], constancy = [128], constant_value =
+  %9 = tt.splat %n : i32 -> tensor<128xi32>
+  // CHECK-NEXT: contiguity = [1], divisibility = [1], constancy = [16], constant_value =
+  %mask = arith.cmpi slt, %4, %9 : tensor<128xi32>
+  // CHECK-NEXT: contiguity = [1], divisibility = [1], constancy = [1], constant_value =
+  %cst = arith.constant dense<0.0> : tensor<128xf32>
+  tt.store %5, %cst, %mask : tensor<128x!tt.ptr>
+  tt.return
+}
+
+// -----
+
+// This IR is dumped from vecadd test.
+// Note that the hint {tt.divisibility = 16 : i32} for %n_elements affects the alignment of the mask.
+// CHECK-LABEL: @vecadd_mask_align_16
+tt.func @vecadd_mask_align_16(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: !tt.ptr {tt.divisibility = 16 : i32}, %n_elements: i32 {tt.divisibility = 16 : i32}) {
+  %c64_i32 = arith.constant 64 : i32
+  %0 = tt.get_program_id x : i32
+  %1 = arith.muli %0, %c64_i32 : i32
+  %2 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32>
+  %3 = tt.splat %1 : i32 -> tensor<64xi32>
+  %4 = arith.addi %3, %2 : tensor<64xi32>
+  %5 = tt.splat %arg0 : !tt.ptr -> tensor<64x!tt.ptr>
+  %6 = tt.addptr %5, %4 : tensor<64x!tt.ptr>, tensor<64xi32>
+  %7 = tt.splat %arg1 : !tt.ptr -> tensor<64x!tt.ptr>
+  %8 = tt.addptr %7, %4 : tensor<64x!tt.ptr>, tensor<64xi32>
+  %9 = tt.splat %n_elements : i32 -> tensor<64xi32>
+  // CHECK: arith.cmpi slt, %{{.*}} => contiguity = [1], divisibility = [1], constancy = [16], constant_value =
+  %mask = arith.cmpi slt, %4, %9 : tensor<64xi32>
+  %11 = tt.load %6, %mask : tensor<64x!tt.ptr>
+  %12 = tt.load %8, %mask : tensor<64x!tt.ptr>
+  %13 = arith.addf %11, %12 : tensor<64xf32>
+  %14 = tt.splat %arg2 : !tt.ptr -> tensor<64x!tt.ptr>
+  // CHECK: tt.addptr %{{.*}} => contiguity = [64], divisibility = [16], constancy = [1], constant_value =
+  %15 = tt.addptr %14, %4 : tensor<64x!tt.ptr>, tensor<64xi32>
+  tt.store %15, %13, %mask : tensor<64x!tt.ptr>
+  tt.return
+}
+
+// -----
+
+// This IR is dumped from vecadd test.
+// Note that there is no divisibility hint for %n_elements; Triton should assume its divisibility to be 1 by default.
+// CHECK-LABEL: @vecadd_mask_align_1
+tt.func @vecadd_mask_align_1(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: !tt.ptr {tt.divisibility = 16 : i32}, %n_elements: i32) {
+  %c64_i32 = arith.constant 64 : i32
+  %0 = tt.get_program_id x : i32
+  %1 = arith.muli %0, %c64_i32 : i32
+  %2 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32>
+  %3 = tt.splat %1 : i32 -> tensor<64xi32>
+  %4 = arith.addi %3, %2 : tensor<64xi32>
+  %5 = tt.splat %arg0 : !tt.ptr -> tensor<64x!tt.ptr>
+  %6 = tt.addptr %5, %4 : tensor<64x!tt.ptr>, tensor<64xi32>
+  %7 = tt.splat %arg1 : !tt.ptr -> tensor<64x!tt.ptr>
+  %8 = tt.addptr %7, %4 : tensor<64x!tt.ptr>, tensor<64xi32>
+  %9 = tt.splat %n_elements : i32 -> tensor<64xi32>
+  // CHECK: arith.cmpi slt, %{{.*}} => contiguity = [1], divisibility = [1], constancy = [1], constant_value =
+  %10 = arith.cmpi slt, %4, %9 : tensor<64xi32>
+  %11 = tt.load %6, %10 : tensor<64x!tt.ptr>
+  %12 = tt.load %8, %10 : tensor<64x!tt.ptr>
+  %13 = arith.addf %11, %12 : tensor<64xf32>
+  %14 = tt.splat %arg2 : !tt.ptr -> tensor<64x!tt.ptr>
+  %15 = tt.addptr %14, %4 : tensor<64x!tt.ptr>, tensor<64xi32>
+  tt.store %15, %13, %10 : tensor<64x!tt.ptr>
+  tt.return
+}
+
+// -----
+
+module {
+
+// We don't use function cloning here, so the alignment info is the gcd of all call sites.
+// CHECK-LABEL: @addptr_hints +tt.func @addptr_hints(%arg0: !tt.ptr) { + // CHECK: contiguity = [1], divisibility = [1], constancy = [1], constant_value = 1 + %cst1 = arith.constant 1 : i32 + // CHECK-NEXT: contiguity = [1], divisibility = [4], constancy = [1], constant_value = + %1 = tt.addptr %arg0, %cst1 : !tt.ptr, i32 + // CHECK-NEXT: contiguity = [1], divisibility = [4], constancy = [1], constant_value = 4 + %cst4 = arith.constant 4 : i32 + // CHECK-NEXT: contiguity = [1], divisibility = [4], constancy = [1], constant_value = + %2 = tt.addptr %arg0, %cst4 : !tt.ptr, i32 + // CHECK-NEXT: contiguity = [1], divisibility = [16], constancy = [1], constant_value = 16 + %cst16 = arith.constant 16 : i32 + // CHECK-NEXT: contiguity = [1], divisibility = [4], constancy = [1], constant_value = + %3 = tt.addptr %arg0, %cst4 : !tt.ptr, i32 + tt.return +} + +// CHECK-LABEL: @kernel_div16 +tt.func @kernel_div16(%arg0: !tt.ptr {tt.divisibility = 16 : i32}) { + tt.call @addptr_hints(%arg0) : (!tt.ptr) -> () + tt.return +} + +// CHECK-LABEL: @kernel_div8 +tt.func @kernel_div8(%arg0: !tt.ptr {tt.divisibility = 8 : i32}) { + tt.call @addptr_hints(%arg0) : (!tt.ptr) -> () + tt.return +} + +// CHECK-LABEL: @kernel_div4 +tt.func @kernel_div4(%arg0: !tt.ptr {tt.divisibility = 4 : i32}) { + tt.call @addptr_hints(%arg0) : (!tt.ptr) -> () + tt.return +} + +} + +// ----- + +module { + +// We don't use function cloning here, so the alignment info is the gcd of all call sites. +// CHECK-LABEL: @mul +tt.func @mul(%arg0: i32) { + // CHECK: contiguity = [1], divisibility = [1], constancy = [1], constant_value = 1 + %cst1 = arith.constant 1 : i32 + // CHECK-NEXT: contiguity = [1], divisibility = [4], constancy = [1], constant_value = + %1 = arith.muli %arg0, %cst1 : i32 + tt.return +} + +// CHECK-LABEL: @bar +tt.func @bar(%arg0: i32) { + tt.call @mul(%arg0) : (i32) -> () + tt.return +} + +// CHECK-LABEL: @foo +tt.func @foo(%arg0: i32) { + tt.call @mul(%arg0) : (i32) -> () + tt.return +} + +// CHECK-LABEL: @call_graph +tt.func @call_graph(%arg0: i32) { + // CHECK: contiguity = [1], divisibility = [4], constancy = [1], constant_value = 12 + %cst12 = arith.constant 12 : i32 + // CHECK: contiguity = [1], divisibility = [4], constancy = [1], constant_value = + %0 = arith.muli %arg0, %cst12 : i32 + tt.call @foo(%0) : (i32) -> () + // CHECK: contiguity = [1], divisibility = [8], constancy = [1], constant_value = 8 + %cst8 = arith.constant 8 : i32 + // CHECK: contiguity = [1], divisibility = [8], constancy = [1], constant_value = + %1 = arith.muli %arg0, %cst8 : i32 + tt.call @bar(%1) : (i32) -> () + tt.return +} + +} + +// ----- + +// CHECK-LABEL: @tensor_ptr +tt.func @tensor_ptr(%arg0: !tt.ptr, 1>) { + // CHECK: contiguity = [1, 1], divisibility = [1, 1], constancy = [1, 1], constant_value = + %0 = tt.load %arg0 : !tt.ptr, 1> + tt.return +} + + +// ----- + +// CHECK-LABEL: @chained_for +tt.func public @chained_for(%8: tensor<128x64x!tt.ptr> {tt.divisibility = 16 : i32}) { + // CHECK: contiguity = [1, 1], divisibility = [1, 1], constancy = [1, 1], constant_value = + %cst = arith.constant dense<0.000000e+00> : tensor<128x64xbf16> + // CHECK: contiguity = [1], divisibility = [16], constancy = [1], constant_value = 16 + %c16_i32 = arith.constant 16 : i32 + // CHECK: contiguity = [1], divisibility = [1], constancy = [1], constant_value = 1 + %c1_i32 = arith.constant 1 : i32 + // CHECK: contiguity = [1], divisibility = [4611686018427387904], constancy = [1], constant_value = 0 + %c0_i32 = arith.constant 0 : i32 + // CHECK: 
contiguity = [1, 1], divisibility = [64, 64], constancy = [128, 64], constant_value = 64 + %cst_0 = arith.constant dense<64> : tensor<128x64xi32> + // CHECK: contiguity = [1, 1], divisibility = [16, 16], constancy = [1, 1], constant_value = + %9 = scf.for %arg7 = %c0_i32 to %c16_i32 step %c1_i32 iter_args(%arg8 = %8) -> (tensor<128x64x!tt.ptr>) : i32 { + %11 = tt.addptr %arg8, %cst_0 : tensor<128x64x!tt.ptr>, tensor<128x64xi32> + scf.yield %11 : tensor<128x64x!tt.ptr> + } + // CHECK: contiguity = [1, 1], divisibility = [16, 16], constancy = [1, 1], constant_value = + // CHECK: contiguity = [1, 1], divisibility = [16, 16], constancy = [1, 1], constant_value = + %10 = scf.for %arg7 = %c0_i32 to %c16_i32 step %c1_i32 iter_args(%arg8 = %9) -> (tensor<128x64x!tt.ptr>) : i32 { + tt.store %arg8, %cst : tensor<128x64x!tt.ptr> + %11 = tt.addptr %arg8, %cst_0 : tensor<128x64x!tt.ptr>, tensor<128x64xi32> + scf.yield %11 : tensor<128x64x!tt.ptr> + } + tt.return +} + +// ----- + +// CHECK-LABEL: @int_min_does_not_underflow_in_analysis +tt.func @int_min_does_not_underflow_in_analysis() -> i64 { + // CHECK: divisibility = [4611686018427387904] + %int_min = arith.constant -9223372036854775808 : i64 + tt.return %int_min : i64 +} + +// ----- + +// CHECK-LABEL: @make_tensor_ptr +tt.func public @make_tensor_ptr(%arg0: !tt.ptr, %arg1: !tt.ptr {tt.divisibility = 32 : i32}, %arg2: i64 {tt.divisibility = 16 : i32}) { + %c0_i32 = arith.constant 0 : i32 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c128_i64 = arith.constant 128 : i64 + // CHECK: %0 = tt.make_tensor_ptr %arg0, {{.*}} => contiguity = [128, 32], divisibility = [1, 1], constancy = [1, 1], constant_value = + %0 = tt.make_tensor_ptr %arg0, [%c128_i64, %c32_i64], [%c1_i64, %c1_i64], [%c0_i32, %c0_i32] {order = array} : !tt.ptr> + // CHECK-NEXT: %1 = tt.make_tensor_ptr %arg1, {{.*}} => contiguity = [32, 1], divisibility = [16, 1], constancy = [1, 1], constant_value = + %1 = tt.make_tensor_ptr %arg1, [%c32_i64, %c32_i64], [%c1_i64, %arg2], [%c0_i32, %c0_i32] {order = array} : > + tt.return +} diff --git a/test/lib/Analysis/CMakeLists.txt b/test/lib/Analysis/CMakeLists.txt index da7bc3f78a..75f785ce24 100644 --- a/test/lib/Analysis/CMakeLists.txt +++ b/test/lib/Analysis/CMakeLists.txt @@ -1,4 +1,5 @@ add_mlir_library(TritonTestAnalysis + intel/TestAxisInfo.cpp TestAlias.cpp TestAxisInfo.cpp TestAllocation.cpp diff --git a/test/lib/Analysis/intel/TestAxisInfo.cpp b/test/lib/Analysis/intel/TestAxisInfo.cpp new file mode 100644 index 0000000000..68c14999fb --- /dev/null +++ b/test/lib/Analysis/intel/TestAxisInfo.cpp @@ -0,0 +1,47 @@ +#include "intel/include/Analysis/AxisInfo.h" +#include "mlir/Pass/Pass.h" + +using namespace mlir; +using namespace mlir::triton; + +namespace { + +struct TestAxisInfoPass + : public PassWrapper> { + + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestAxisInfoPass); + + StringRef getArgument() const final { return "test-print-axis-info"; } + StringRef getDescription() const final { + return "print the result of the axis analysis analysis pass"; + } + + void runOnOperation() override { + Operation *operation = getOperation(); + ModuleOp moduleOp = cast(operation); + intel::ModuleAxisInfoAnalysis moduleAxisInfoAnalysis(moduleOp); + moduleOp.walk([&](FuncOp funcOp) { + auto &os = llvm::errs(); + auto opName = SymbolTable::getSymbolName(funcOp).getValue().str(); + os << "@" << opName << "\n"; + funcOp.walk([&](Operation *op) { + if (op->getNumResults() < 1) + return; + for (Value result : op->getResults()) { 
+ result.print(os); + os << " => "; + auto *axisInfo = moduleAxisInfoAnalysis.getAxisInfo(result); + if (axisInfo) + axisInfo->print(os); + os << "\n"; + } + }); + }); + } +}; + +} // namespace + +namespace mlir::test::intel { +void registerTestAxisInfoPass() { PassRegistration(); } +} // namespace mlir::test::intel diff --git a/third_party/intel/include/Analysis/AxisInfo.h b/third_party/intel/include/Analysis/AxisInfo.h new file mode 100644 index 0000000000..7dc9f09422 --- /dev/null +++ b/third_party/intel/include/Analysis/AxisInfo.h @@ -0,0 +1,215 @@ +#ifndef TRITON_ANALYSIS_AXISINFO_H +#define TRITON_ANALYSIS_AXISINFO_H + +#include "mlir/Analysis/DataFlow/SparseAnalysis.h" +#include "llvm/Support/raw_ostream.h" + +#include "mlir/Support/LLVM.h" +#include "triton/Analysis/Utility.h" +#include "triton/Dialect/Triton/IR/Dialect.h" +#include "triton/Dialect/Triton/IR/Utility.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" + +#include +#include + +namespace mlir::triton::intel { + +//===----------------------------------------------------------------------===// +// AxisInfo +//===----------------------------------------------------------------------===// + +/// This lattice value represents known information on the axes of a lattice. +class AxisInfo { +public: + typedef SmallVector DimVectorT; + +public: + AxisInfo() : AxisInfo({}, {}, {}) {} + + AxisInfo(DimVectorT contiguity, DimVectorT divisibility, DimVectorT constancy) + : AxisInfo(contiguity, divisibility, constancy, std::nullopt) {} + + AxisInfo(DimVectorT contiguity, DimVectorT divisibility, DimVectorT constancy, + std::optional constantValue) + : contiguity(contiguity), divisibility(divisibility), + constancy(constancy), constantValue(constantValue) { + assert(divisibility.size() == contiguity.size()); + assert(constancy.size() == contiguity.size()); + } + + // contiguity[d] is the length of the shortest sequence of contiguous integers + // along dimension d. + // + // If we have an array of N elements with a contiguity value C, then the array + // can be divided into a list of N/C sequences of C contiguous elements. + // Since we have N = 2^k, C must be a power of two. + // + // For example, the 2D array + // + // [[10, 11, 12, 13, 18, 19, 20, 21], + // [20, 21, 22, 23, 28, 29, 30, 31]] + // + // has contiguity [1, 4], and + // + // [[12, 16, 20, 24], + // [13, 17, 21, 25], + // [14, 18, 22, 26], + // [15, 19, 23, 27], + // [18, 22, 26, 30], + // [19, 23, 27, 31]] + // + // has contiguity [2, 1]. + int64_t getContiguity(size_t dim) const { return contiguity[dim]; } + const DimVectorT &getContiguity() const { return contiguity; } + + // divisibility[d] is the largest power of two that divides the first element + // of all groups of length contiguity[d] along dimension d. + // + // For example, + // + // [[10, 11, 12, 13, 18, 19, 20, 21], + // [20, 21, 22, 23, 28, 29, 30, 31]] + // + // has divisibility [1, 2], and + // + // [[12, 16, 20, 24], + // [13, 17, 21, 25], + // [14, 18, 22, 26], + // [15, 19, 23, 27]] + // + // has divisibility [4, 1]. + // + // On the other hand, + // + // [0, 1, 2, 0, 4, 5, 6, 7] + // + // has divisibility 1 because its contiguity is 1. + int64_t getDivisibility(size_t dim) const { return divisibility[dim]; } + const DimVectorT &getDivisibility() const { return divisibility; } + + // constancy[d] is the length of the shortest sequence of repeating integers + // along dimension d. + // + // This is particularly useful to infer the contiguity of operations (e.g. + // add) involving a constant. 
+ // + // If we have an array of N elements, with a constancy value C, then the array + // can be divided into a list of N/C sequences of C elements with the same + // value. Since we have N = 2^k, C must be a power of two. + // + // For example + // + // [[8, 8, 8, 8, 12, 12, 12, 12], + // [16, 16, 16, 16, 20, 20, 20, 20]] + // + // has constancy [1, 4]. + int64_t getConstancy(size_t dim) const { return constancy[dim]; } + const DimVectorT &getConstancy() const { return constancy; } + + int getRank() const { return contiguity.size(); } + + std::optional getConstantValue() const { return constantValue; } + + template + static void + initPessimisticStateFromFunc(int argNumber, T funcOp, DimVectorT *contiguity, + DimVectorT *divisibility, DimVectorT *constancy); + + bool operator==(const AxisInfo &other) const { + return contiguity == other.contiguity && + divisibility == other.divisibility && constancy == other.constancy && + constantValue == other.constantValue; + } + + static AxisInfo getPessimisticValueState(Value value); + + // The gcd of both arguments for each dimension + static AxisInfo join(const AxisInfo &lhs, const AxisInfo &rhs); + + void print(raw_ostream &os) const { + auto print = [&](StringRef name, DimVectorT vec) { + os << name << " = ["; + llvm::interleaveComma(vec, os); + os << "]"; + }; + print("contiguity", contiguity); + print(", divisibility", divisibility); + print(", constancy", constancy); + os << ", constant_value = "; + if (constantValue) + os << *constantValue; + else + os << ""; + } + +private: + DimVectorT contiguity; + DimVectorT divisibility; + DimVectorT constancy; + + // The constant value of the lattice if we can infer it. + std::optional constantValue; +}; + +// Module level axis info analysis based on the call graph, assuming that we do +// not have recursive functions. +// +// Since each function will be called multiple times, we need to calculate the +// axis info based on the axis info of all the callers. In the future, we can +// perform optimization using function cloning so that each call site will have +// unique axis info. 
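+//
+// As a minimal illustration of the gcd rule (hypothetical kernels, not part of
+// this patch; the @addptr_hints/@kernel_div* case in test-axis-info.mlir
+// exercises the same behaviour):
+//
+//   tt.func @helper(%p: !tt.ptr) {
+//     tt.return
+//   }
+//   tt.func @kernel_a(%p: !tt.ptr {tt.divisibility = 16 : i32}) {
+//     tt.call @helper(%p) : (!tt.ptr) -> ()
+//     tt.return
+//   }
+//   tt.func @kernel_b(%p: !tt.ptr {tt.divisibility = 8 : i32}) {
+//     tt.call @helper(%p) : (!tt.ptr) -> ()
+//     tt.return
+//   }
+//
+// Inside @helper, %p is assumed to have divisibility gcd(16, 8) = 8, i.e. the
+// join of the axis info flowing in from the two call sites.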
+using AxisInfoMapT = DenseMap<Value, AxisInfo>;
+class ModuleAxisInfoAnalysis : public CallGraph<AxisInfoMapT> {
+public:
+  explicit ModuleAxisInfoAnalysis(ModuleOp moduleOp)
+      : CallGraph<AxisInfoMapT>(moduleOp) {
+    SmallVector<FunctionOpInterface> funcs;
+    for (auto root : getRoots()) {
+      walk<WalkOrder::PreOrder, WalkOrder::PostOrder>(
+          // Pre-order edge walk callback
+          [](CallOpInterface callOp, FunctionOpInterface funcOp) {},
+          // Post-order node walk callback
+          [&](FunctionOpInterface funcOp) {
+            funcs.push_back(funcOp);
+            funcMap.try_emplace(funcOp, AxisInfoMapT{});
+          });
+    }
+    SetVector<FunctionOpInterface> sortedFuncs(funcs.begin(), funcs.end());
+    SymbolTableCollection symbolTable;
+    for (auto funcOp : llvm::reverse(sortedFuncs)) {
+      initialize(funcOp);
+      funcOp.walk([&](CallOpInterface callOp) {
+        auto callee = dyn_cast<FunctionOpInterface>(
+            callOp.resolveCallableInTable(&symbolTable));
+        update(callOp, callee);
+      });
+    }
+  }
+
+  AxisInfo *getAxisInfo(Value value) {
+    auto funcOp =
+        value.getParentRegion()->getParentOfType<FunctionOpInterface>();
+    auto *axisInfoMap = getFuncData(funcOp);
+    if (!axisInfoMap) {
+      return nullptr;
+    }
+    auto it = axisInfoMap->find(value);
+    if (it == axisInfoMap->end()) {
+      return nullptr;
+    }
+    return &(it->second);
+  }
+
+  unsigned getPtrContiguity(Value ptr);
+  unsigned getPtrAlignment(Value ptr);
+  unsigned getMaskAlignment(Value mask);
+
+private:
+  void initialize(FunctionOpInterface funcOp);
+  void update(CallOpInterface callOp, FunctionOpInterface funcOp);
+};
+
+} // namespace mlir::triton::intel
+
+#endif
diff --git a/third_party/intel/lib/Analysis/AxisInfo.cpp b/third_party/intel/lib/Analysis/AxisInfo.cpp
new file mode 100644
index 0000000000..4c63c1f25f
--- /dev/null
+++ b/third_party/intel/lib/Analysis/AxisInfo.cpp
@@ -0,0 +1,1420 @@
+#include "mlir/Analysis/DataFlowFramework.h"
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "llvm/Support/Debug.h"
+#include "llvm/Support/raw_ostream.h"
+
+#include "third_party/intel/include/Analysis/AxisInfo.h"
+#include "triton/Dialect/Triton/IR/Dialect.h"
+#include "triton/Dialect/TritonGPU/IR/Dialect.h"
+
+#define DEBUG_TYPE "axis-info"
+#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
+#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n")
+
+namespace mlir::triton::intel {
+namespace {
+
+int64_t gcdImpl(int64_t a, int64_t b, int64_t *x, int64_t *y) {
+  // Base case
+  if (a == 0) {
+    *x = 0;
+    *y = 1;
+    return b;
+  }
+  int64_t x1, y1; // To store results of recursive call
+  int64_t gcd = gcdImpl(b % a, a, &x1, &y1);
+  // Update x and y using results of the recursive call
+  *x = y1 - (b / a) * x1;
+  *y = x1;
+  return gcd;
+}
+
+int64_t gcd(int64_t a, int64_t b) {
+  if (a == 0)
+    return b;
+  if (b == 0)
+    return a;
+  int64_t x, y;
+  return gcdImpl(a, b, &x, &y);
+}
+
+constexpr int log2Int(int64_t num) {
+  return (num > 1) ?
1 + log2Int(num / 2) : 0; +} + +// If lhs * rhs overflows, return max value possible value for the type +int64_t multiplyDivisor(int64_t lhs, int64_t rhs) { + int64_t maxDivisor = highestPowOf2Divisor(0); + if (lhs > maxDivisor / rhs) + return maxDivisor; + return lhs * rhs; +} + +class AxisInfoVisitor { +public: + AxisInfoVisitor() = default; + virtual ~AxisInfoVisitor() = default; + + static bool isContiguousDim(const AxisInfo &info, ArrayRef shape, + int dim) { + return info.getContiguity(dim) == shape[dim]; + } + + static bool isConstantDim(const AxisInfo &info, ArrayRef shape, + int dim) { + return info.getConstancy(dim) == shape[dim]; + } + + virtual AxisInfo + getAxisInfo(Operation *op, + ArrayRef *> operands) = 0; + + virtual bool match(Operation *op) = 0; +}; + +// Base class for all operations +template class AxisInfoVisitorImpl : public AxisInfoVisitor { +public: + using AxisInfoVisitor::AxisInfoVisitor; + + AxisInfo + getAxisInfo(Operation *op, + ArrayRef *> operands) final { + return getAxisInfo(cast(op), operands); + } + + bool match(Operation *op) final { return isa(op); } + + virtual AxisInfo + getAxisInfo(OpTy op, + ArrayRef *> operands) = 0; +}; + +// Binary operations +template +class BinaryOpVisitorImpl : public AxisInfoVisitorImpl { +public: + using AxisInfoVisitorImpl::AxisInfoVisitorImpl; + + AxisInfo + getAxisInfo(OpTy op, + ArrayRef *> operands) override { + auto lhsInfo = operands[0]->getValue(); + auto rhsInfo = operands[1]->getValue(); + auto rank = lhsInfo.getRank(); + assert(operands.size() == 2 && "Expected two operands"); + AxisInfo::DimVectorT contiguity; + AxisInfo::DimVectorT divisibility; + AxisInfo::DimVectorT constancy; + auto constantValue = getConstantValue(op, lhsInfo, rhsInfo); + for (auto d = 0; d < rank; ++d) { + if (constantValue.has_value()) { + contiguity.push_back(1); + constancy.push_back( + std::max(lhsInfo.getConstancy(d), rhsInfo.getConstancy(d))); + divisibility.push_back( + highestPowOf2Divisor(constantValue.value())); + } else { + contiguity.push_back(getContiguity(op, lhsInfo, rhsInfo, d)); + constancy.push_back(getConstancy(op, lhsInfo, rhsInfo, d)); + divisibility.push_back(getDivisibility(op, lhsInfo, rhsInfo, d)); + } + } + return AxisInfo(contiguity, divisibility, constancy, constantValue); + } + +protected: + virtual int64_t getContiguity(OpTy op, const AxisInfo &lhs, + const AxisInfo &rhs, int dim) { + return 1; + } + + virtual int64_t getDivisibility(OpTy op, const AxisInfo &lhs, + const AxisInfo &rhs, int dim) { + return 1; + } + + virtual int64_t getConstancy(OpTy op, const AxisInfo &lhs, + const AxisInfo &rhs, int dim) { + return 1; + } + + virtual std::optional getConstantValue(OpTy op, const AxisInfo &lhs, + const AxisInfo &rhs) { + return {}; + } +}; + +class AxisInfoVisitorList { +public: + template > + void append() { + (visitors.emplace_back(std::make_unique()), ...); + } + + AxisInfo apply(Operation *op, + ArrayRef *> operands) { + for (auto &visitor : visitors) + if (visitor->match(op)) + return visitor->getAxisInfo(op, operands); + return AxisInfo(); + } + +private: + std::vector> visitors; +}; + +class AxisInfoAnalysis : public dataflow::SparseForwardDataFlowAnalysis< + dataflow::Lattice> { +private: + AxisInfoVisitorList visitors; + + void setToEntryState(dataflow::Lattice *lattice) override { + propagateIfChanged( + lattice, lattice->join( + AxisInfo::getPessimisticValueState(lattice->getAnchor()))); + } + + void visitNonControlFlowArguments( + Operation *op, const RegionSuccessor &successor, + ArrayRef *> 
argLattices, + unsigned firstIndex) override { + if (auto forOp = dyn_cast(op)) { + visitForOpInductionVar(forOp, argLattices); + } else { + setAllToEntryStates(argLattices.take_front(firstIndex)); + setAllToEntryStates(argLattices.drop_front( + firstIndex + successor.getSuccessorInputs().size())); + } + } + +public: + AxisInfoAnalysis(DataFlowSolver &solver); + using dataflow::SparseForwardDataFlowAnalysis< + dataflow::Lattice>::getLatticeElement; + using FuncAxisInfoMapT = DenseMap; + + LogicalResult + visitOperation(Operation *op, + ArrayRef *> operands, + ArrayRef *> results) override; + void + visitForOpInductionVar(scf::ForOp op, + ArrayRef *> argLattices); +}; + +template +class CastOpAxisInfoVisitor final : public AxisInfoVisitorImpl { +public: + using AxisInfoVisitorImpl::AxisInfoVisitorImpl; + + AxisInfo + getAxisInfo(OpTy op, + ArrayRef *> operands) override { + return operands[0]->getValue(); + } +}; + +class MakeRangeOpAxisInfoVisitor final + : public AxisInfoVisitorImpl { +public: + using AxisInfoVisitorImpl::AxisInfoVisitorImpl; + + AxisInfo + getAxisInfo(triton::MakeRangeOp op, + ArrayRef *> operands) override { + auto start = op.getStart(); + auto end = op.getEnd(); + return AxisInfo(/*contiguity=*/{end - start}, + /*divisibility=*/{highestPowOf2Divisor(start)}, + /*constancy=*/{1}); + } +}; + +template +class ConstantOpAxisInfoVisitor final : public AxisInfoVisitorImpl { +public: + using AxisInfoVisitorImpl::AxisInfoVisitorImpl; + + AxisInfo + getAxisInfo(OpTy op, + ArrayRef *> operands) override { + auto intAttr = dyn_cast(op.getValue()); + auto boolAttr = dyn_cast(op.getValue()); + if (intAttr || boolAttr) { + int64_t value{}; + if (intAttr) + value = intAttr.getValue().getZExtValue(); + else + value = boolAttr.getValue() ? 1 : 0; + return AxisInfo(/*contiguity=*/{1}, + /*divisibility=*/{highestPowOf2Divisor(value)}, + /*constancy=*/{1}, + /*knownConstantValue=*/{value}); + } + // TODO: generalize to dense attr + auto splatAttr = dyn_cast(op.getValue()); + if (splatAttr && splatAttr.getElementType().isIntOrIndex()) { + int64_t value = splatAttr.template getSplatValue().getZExtValue(); + TensorType ty = cast(splatAttr.getType()); + return AxisInfo( + /*contiguity=*/AxisInfo::DimVectorT(ty.getRank(), 1), + /*divisibility=*/ + AxisInfo::DimVectorT(ty.getRank(), highestPowOf2Divisor(value)), + /*constancy=*/ + AxisInfo::DimVectorT(ty.getShape().begin(), ty.getShape().end()), + /*knownConstantValue=*/{value}); + } + return AxisInfo(); + } +}; + +template +class AddSubOpAxisInfoVisitor final : public BinaryOpVisitorImpl { +public: + using BinaryOpVisitorImpl::BinaryOpVisitorImpl; + +private: + int64_t getContiguity(OpTy op, const AxisInfo &lhs, const AxisInfo &rhs, + int dim) override { + return std::max(gcd(lhs.getConstancy(dim), rhs.getContiguity(dim)), + gcd(lhs.getContiguity(dim), rhs.getConstancy(dim))); + } + + int64_t getDivisibility(OpTy op, const AxisInfo &lhs, const AxisInfo &rhs, + int dim) override { + // lhs = k * d_lhs = k * k' * gcd(d_lhs, d_rhs) + // rhs = p * d_rhs = p * p' * gcd(d_lhs, d_rhs) + // lhs + rhs = k * d_lhs + p * d_rhs = (k * k' + p * p') * gcd(d_lhs, d_rhs) + auto rhsDivisibility = rhs.getDivisibility(dim); + if constexpr (std::is_same_v) { + // %ptr = addptr %lhs, %rhs + // is equivalent to + // %0 = mul %rhs, %elemSize + // %ptr = add %lhs, %0 + // The result will still be contiguous in terms of elements but not bytes + // For example: + // addptr [16] : !ptr, [0, 1, 2, 3] : i32 -> !ptr + // returns: + // [16, 20, 24, 28] : !ptr + // with 
element locations: + // [4, 5, 6, 7] + // It is "strided contiguous" with a divisilibity of 16 bytes + auto rank = lhs.getRank(); + auto elemSize = std::max( + 1, triton::getPointeeBitWidth(op.getPtr().getType()) / 8); + rhsDivisibility = multiplyDivisor(rhs.getDivisibility(dim), elemSize); + } + return gcd(lhs.getDivisibility(dim), rhsDivisibility); + } + + int64_t getConstancy(OpTy op, const AxisInfo &lhs, const AxisInfo &rhs, + int dim) override { + return gcd(lhs.getConstancy(dim), rhs.getConstancy(dim)); + } + + std::optional getConstantValue(OpTy op, const AxisInfo &lhs, + const AxisInfo &rhs) override { + if (lhs.getConstantValue().has_value() && + rhs.getConstantValue().has_value()) { + if constexpr (std::is_same_v || + std::is_same_v) { + return {lhs.getConstantValue().value() + + rhs.getConstantValue().value()}; + } else if constexpr (std::is_same_v) { + return {lhs.getConstantValue().value() - + rhs.getConstantValue().value()}; + } else if constexpr (std::is_same_v) { + auto rank = lhs.getRank(); + auto elemSize = std::max( + 1, triton::getPointeeBitWidth(op.getPtr().getType()) / 8); + auto rhsValue = rhs.getConstantValue().value() * elemSize; + return {lhs.getConstantValue().value() + rhsValue}; + } + } + return {}; + } +}; + +class MulIOpAxisInfoVisitor final : public BinaryOpVisitorImpl { +public: + using BinaryOpVisitorImpl::BinaryOpVisitorImpl; + +private: + int64_t getContiguity(arith::MulIOp op, const AxisInfo &lhs, + const AxisInfo &rhs, int dim) override { + // lhs * 1 = lhs + auto lhsContiguity = + rhs.getConstantValue().has_value() && rhs.getConstantValue() == 1 + ? lhs.getContiguity(dim) + : 1; + // 1 * rhs = rhs + auto rhsContiguity = + lhs.getConstantValue().has_value() && lhs.getConstantValue() == 1 + ? rhs.getContiguity(dim) + : 1; + return std::max(lhsContiguity, rhsContiguity); + } + + int64_t getConstancy(arith::MulIOp op, const AxisInfo &lhs, + const AxisInfo &rhs, int dim) override { + return gcd(lhs.getConstancy(dim), rhs.getConstancy(dim)); + } + + int64_t getDivisibility(arith::MulIOp op, const AxisInfo &lhs, + const AxisInfo &rhs, int dim) override { + auto lhsDivisibility = lhs.getDivisibility(dim); + if (lhs.getContiguity(dim) > 1 && + !(rhs.getConstantValue().has_value() && rhs.getConstantValue() == 1)) { + // Treat [2^n,2^n+1,...]'s divisibility as 1 instead of 2^n + lhsDivisibility = 1; + } + auto rhsDivisibility = rhs.getDivisibility(dim); + if (rhs.getContiguity(dim) > 1 && + !(lhs.getConstantValue().has_value() && lhs.getConstantValue() == 1)) { + // Treat [2^n,2^n+1,...]'s divisibility as 1 instead of 2^n + rhsDivisibility = 1; + } + return multiplyDivisor(lhsDivisibility, rhsDivisibility); + } + + std::optional getConstantValue(arith::MulIOp op, const AxisInfo &lhs, + const AxisInfo &rhs) override { + if (lhs.getConstantValue().has_value() && + rhs.getConstantValue().has_value()) + return {lhs.getConstantValue().value() * rhs.getConstantValue().value()}; + return {}; + } +}; + +template +class DivOpAxisInfoVisitor final : public BinaryOpVisitorImpl { +public: + using BinaryOpVisitorImpl::BinaryOpVisitorImpl; + +private: + int64_t getContiguity(OpTy op, const AxisInfo &lhs, const AxisInfo &rhs, + int dim) override { + // lhs / 1 = lhs + return rhs.getConstantValue().has_value() && + rhs.getConstantValue().value() == 1 + ? lhs.getContiguity(dim) + : 1; + } + + int64_t getConstancy(OpTy op, const AxisInfo &lhs, const AxisInfo &rhs, + int dim) override { + Type ptrTy = op.getType(); + auto resTy = + isTensorPointerType(ptrTy) + ? 
cast(cast(ptrTy).getPointeeType()) + : dyn_cast(ptrTy); + + if (!resTy) + return BinaryOpVisitorImpl::getConstancy(op, lhs, rhs, dim); + auto shape = resTy.getShape(); + // Case 1: both lhs and rhs are constants. + auto constancy = gcd(lhs.getConstancy(dim), rhs.getConstancy(dim)); + // Case 2: lhs contiguous, rhs constant. + // lhs: d_lhs * k, d_lhs * k + 1, ..., d_lhs * k + n + // rhs: d_rhs * p, d_rhs * p, ..., d_rhs * p + // lhs / rhs = d_lhs * k / (d_rhs * p), (d_lhs * k + 1) / (d_rhs * p), + // ..., (d_lhs * k + n) / (d_rhs * p) + // Because d_lhs % d_rhs = 0 || d_rhs % d_lhs = 0, + // the minimal constancy is gcd(d_lhs, d_rhs). + // Since gcd(d_lhs, d_rhs) maybe > len(lhs), + // we need to use another gcd to get the actual constancy. + if (AxisInfoVisitor::isContiguousDim(lhs, shape, dim) && + AxisInfoVisitor::isConstantDim(rhs, shape, dim)) { + constancy = std::max(constancy, gcd(lhs.getContiguity(dim), + gcd(lhs.getDivisibility(dim), + rhs.getDivisibility(dim)))); + } + return constancy; + } + + int64_t getDivisibility(OpTy op, const AxisInfo &lhs, const AxisInfo &rhs, + int dim) override { + // Case 1: lhs is 0 + if (lhs.getConstantValue().has_value() && + lhs.getConstantValue().value() == 0) + return lhs.getDivisibility(dim); + // Case 2: rhs is 1 + if (rhs.getConstantValue().has_value() && + rhs.getConstantValue().value() == 1) + return lhs.getDivisibility(dim); + // otherwise: return 1 + return 1; + } + + std::optional getConstantValue(OpTy op, const AxisInfo &lhs, + const AxisInfo &rhs) override { + if (lhs.getConstantValue().has_value() && + rhs.getConstantValue().has_value()) + return {lhs.getConstantValue().value() / rhs.getConstantValue().value()}; + return {}; + } +}; + +template +class RemOpAxisInfoVisitor final : public BinaryOpVisitorImpl { +public: + using BinaryOpVisitorImpl::BinaryOpVisitorImpl; + +private: + int64_t getContiguity(OpTy op, const AxisInfo &lhs, const AxisInfo &rhs, + int dim) override { + Type ptrTy = op.getType(); + auto resTy = + isTensorPointerType(ptrTy) + ? cast(cast(ptrTy).getPointeeType()) + : dyn_cast(ptrTy); + + if (!resTy) + return BinaryOpVisitorImpl::getContiguity(op, lhs, rhs, dim); + auto shape = resTy.getShape(); + int64_t contiguity = 1; + // lhs contiguous, rhs constant + // lhs: d_lhs * k, d_lhs * k + 1, ..., d_lhs * k + n + // rhs: d_rhs * p, d_rhs * p, ..., d_rhs * p + // lhs % rhs = d_lhs * k % (d_rhs * p), (d_lhs * k + 1) % (d_rhs * p), + // ..., (d_lhs * k + n) % (d_rhs * p) + // Because d_lhs % d_rhs = 0 || d_rhs % d_lhs = 0, + // The minimal contiguity is gcd(d_lhs, d_rhs). + // Since gcd(d_lhs, d_rhs) maybe > len(lhs), + // we need to use another gcd to get the actual contiguity. 
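+    // Illustrative example (hypothetical values):
+    //   lhs = [16, 17, 18, 19, 20, 21, 22, 23]  (contiguity 8, divisibility 16)
+    //   rhs = [4, 4, 4, 4, 4, 4, 4, 4]          (constant, divisibility 4)
+    //   lhs % rhs = [0, 1, 2, 3, 0, 1, 2, 3]
+    // gcd(8, gcd(16, 4)) = 4, which matches the contiguous runs of length 4.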
+ if (AxisInfoVisitor::isContiguousDim(lhs, shape, dim) && + AxisInfoVisitor::isConstantDim(rhs, shape, dim)) { + contiguity = std::max(contiguity, gcd(lhs.getContiguity(dim), + gcd(lhs.getDivisibility(dim), + rhs.getDivisibility(dim)))); + } + return contiguity; + } + + int64_t getDivisibility(OpTy op, const AxisInfo &lhs, const AxisInfo &rhs, + int dim) override { + // lhs: d_lhs * k = gcd(d_lhs, d_rhs) * k' * k = gcd(d_lhs, d_rhs) * k'' + // rhs: d_rhs * p = gcd(d_lhs, d_rhs) * p' * p = gcd(d_lhs, d_rhs) * p'' + // lhs = gcd(d_lhs, d_rhs) * k'' = gcd(d_lhs, d_rhs) * d + r + // r must be divisible by gcd(d_lhs, d_rhs) + return gcd(lhs.getDivisibility(dim), rhs.getDivisibility(dim)); + }; + + int64_t getConstancy(OpTy op, const AxisInfo &lhs, const AxisInfo &rhs, + int dim) override { + Type ptrTy = op.getType(); + auto resTy = + isTensorPointerType(ptrTy) + ? cast(cast(ptrTy).getPointeeType()) + : dyn_cast(ptrTy); + if (!resTy) + return BinaryOpVisitorImpl::getConstancy(op, lhs, rhs, dim); + auto shape = resTy.getShape(); + // lhs % 1 = 0 + return rhs.getConstantValue().has_value() && + rhs.getConstantValue().value() == 1 + ? shape[dim] + : gcd(lhs.getConstancy(dim), rhs.getConstancy(dim)); + } + + std::optional getConstantValue(OpTy op, const AxisInfo &lhs, + const AxisInfo &rhs) override { + if (lhs.getConstantValue().has_value() && + rhs.getConstantValue().has_value()) + return {lhs.getConstantValue().value() % rhs.getConstantValue().value()}; + else if (rhs.getConstantValue().has_value() && + rhs.getConstantValue().value() == 1) + return {0}; + return {}; + } +}; + +class SplatOpAxisInfoVisitor final + : public AxisInfoVisitorImpl { +public: + using AxisInfoVisitorImpl::AxisInfoVisitorImpl; + + AxisInfo + getAxisInfo(triton::SplatOp op, + ArrayRef *> operands) override { + Type _retTy = *op->result_type_begin(); + TensorType retTy = cast(_retTy); + AxisInfo opInfo = operands[0]->getValue(); + AxisInfo::DimVectorT contiguity; + AxisInfo::DimVectorT divisibility; + AxisInfo::DimVectorT constancy; + for (int d = 0; d < retTy.getRank(); ++d) { + contiguity.push_back(1); + divisibility.push_back(opInfo.getDivisibility(0)); + constancy.push_back(retTy.getShape()[d]); + } + return AxisInfo(contiguity, divisibility, constancy, + operands[0]->getValue().getConstantValue()); + } +}; + +class LoadOpAxisInfoVisitor final : public AxisInfoVisitorImpl { +public: + using AxisInfoVisitorImpl::AxisInfoVisitorImpl; + + AxisInfo + getAxisInfo(triton::LoadOp op, + ArrayRef *> operands) override { + // If pointers and mask both have constancy properties, those properties + // will also extend to output. + AxisInfo ptrInfo = operands[0]->getValue(); + std::optional maskInfo; + if (operands.size() > 1) { + maskInfo = operands[1]->getValue(); + } + AxisInfo::DimVectorT contiguity; + AxisInfo::DimVectorT divisibility; + AxisInfo::DimVectorT constancy; + + for (int d = 0; d < ptrInfo.getRank(); ++d) { + contiguity.push_back(1); + divisibility.push_back(1); + constancy.push_back( + gcd(ptrInfo.getConstancy(d), + maskInfo.has_value() ? 
maskInfo->getConstancy(d) : 0)); + } + + return AxisInfo(contiguity, divisibility, constancy); + } +}; + +class ExpandDimsOpAxisInfoVisitor final + : public AxisInfoVisitorImpl { +public: + using AxisInfoVisitorImpl::AxisInfoVisitorImpl; + + AxisInfo + getAxisInfo(triton::ExpandDimsOp op, + ArrayRef *> operands) override { + AxisInfo opInfo = operands[0]->getValue(); + AxisInfo::DimVectorT contiguity = opInfo.getContiguity(); + AxisInfo::DimVectorT divisibility = opInfo.getDivisibility(); + AxisInfo::DimVectorT constancy = opInfo.getConstancy(); + int64_t newDivisibility = 1; + if (opInfo.getConstantValue().has_value()) { + // The tensor is constant, same as ConstantOpAxisInfoVisitor + newDivisibility = highestPowOf2Divisor(opInfo.getConstantValue().value()); + } else if (opInfo.getRank()) { + // Otherwise, calculate the GCD as the new divisibility + // Treat [2^n,2^n+1,...]'s divisibility as 1 instead of 2^n + newDivisibility = + opInfo.getContiguity(0) > 1 ? 1 : opInfo.getDivisibility(0); + for (int d = 1; d < opInfo.getRank(); ++d) { + newDivisibility = + gcd(newDivisibility, + opInfo.getContiguity(d) > 1 ? 1 : opInfo.getDivisibility(d)); + } + } + contiguity.insert(contiguity.begin() + op.getAxis(), 1); + divisibility.insert(divisibility.begin() + op.getAxis(), newDivisibility); + constancy.insert(constancy.begin() + op.getAxis(), 1); + return AxisInfo(contiguity, divisibility, constancy, + operands[0]->getValue().getConstantValue()); + } +}; + +class BroadcastOpAxisInfoVisitor final + : public AxisInfoVisitorImpl { +public: + using AxisInfoVisitorImpl::AxisInfoVisitorImpl; + + AxisInfo + getAxisInfo(triton::BroadcastOp op, + ArrayRef *> operands) override { + Type _retTy = *op->result_type_begin(); + Type _opTy = *op->operand_type_begin(); + TensorType retTy = cast(_retTy); + TensorType opTy = cast(_opTy); + ArrayRef retShape = retTy.getShape(); + ArrayRef opShape = opTy.getShape(); + AxisInfo opInfo = operands[0]->getValue(); + AxisInfo::DimVectorT contiguity; + AxisInfo::DimVectorT divisibility; + AxisInfo::DimVectorT constancy; + for (int d = 0; d < retTy.getRank(); ++d) { + contiguity.push_back(opShape[d] == 1 ? 1 : opInfo.getContiguity(d)); + divisibility.push_back(opInfo.getDivisibility(d)); + constancy.push_back(opShape[d] == 1 ? retShape[d] + : opInfo.getConstancy(d)); + } + return AxisInfo(contiguity, divisibility, constancy, + operands[0]->getValue().getConstantValue()); + } +}; + +template +class CmpOpAxisInfoVisitor final : public AxisInfoVisitorImpl { +public: + using AxisInfoVisitorImpl::AxisInfoVisitorImpl; + + AxisInfo + getAxisInfo(OpTy op, + ArrayRef *> operands) override { + Type ty = op.getType(); + auto resTy = + isTensorPointerType(ty) + ? cast(cast(ty).getPointeeType()) + : dyn_cast(ty); + if (!resTy) + return AxisInfo(); + auto shape = resTy.getShape(); + short rank = resTy.getRank(); + auto lhsInfo = operands[0]->getValue(); + auto rhsInfo = operands[1]->getValue(); + + AxisInfo::DimVectorT contiguity, divisibility, constancy; + std::optional constantValue; + for (short d = 0; d < rank; ++d) { + int64_t constHint = 1; + if (lhsInfo.getConstantValue().has_value() && + rhsInfo.getConstantValue().has_value()) { + constHint = lhsInfo.getConstancy(d); + constantValue = + compare(getPredicate(op), lhsInfo.getConstantValue().value(), + rhsInfo.getConstantValue().value()) + ? 
1 + : 0; + } else { + // Case 1: lhs and rhs are both partial constants + constHint = gcd(lhsInfo.getConstancy(d), rhsInfo.getConstancy(d)); + if ((gtPredicate(getPredicate(op)) || lePredicate(getPredicate(op))) && + AxisInfoVisitor::isConstantDim(lhsInfo, shape, d)) { + // Case 2: lhs all constant, rhs all contiguous + // NOTE: + // lhs: 4 4 4 4 + // rhs: 4 5 6 7 + // lhs eq rhs: 1, 0, 0, 0 + // lhs ne rhs: 0, 1, 1, 1 + // lhs lt rhs: 0, 1, 1, 1 + // lhs le rhs: 1, 1, 1, 1 + // lhs ge rhs: 1, 0, 0, 0 + // lhs gt rhs: 0, 0, 0, 0 + constHint = std::max(constHint, gcd(rhsInfo.getContiguity(d), + gcd(lhsInfo.getDivisibility(d), + rhsInfo.getDivisibility(d)))); + } else if ((ltPredicate(getPredicate(op)) || + gePredicate(getPredicate(op))) && + AxisInfoVisitor::isConstantDim(rhsInfo, shape, d)) { + // Case 3: lhs all contiguous, rhs all constant + // NOTE + // lhs: 4 5 6 7 + // rhs: 4 4 4 4 + // lhs eq rhs: 1, 0, 0, 0 + // lhs ne rhs: 0, 1, 1, 1 + // lhs le rhs: 1, 0, 0, 0 + // lhs lt rhs: 0, 0, 0, 0 + // lhs gt rhs: 0, 1, 1, 1 + // lhs ge rhs: 1, 1, 1, 1 + constHint = std::max(constHint, gcd(lhsInfo.getContiguity(d), + gcd(lhsInfo.getDivisibility(d), + rhsInfo.getDivisibility(d)))); + } + } + + constancy.push_back(constHint); + divisibility.push_back(1); + contiguity.push_back(1); + } + + return AxisInfo(contiguity, divisibility, constancy, constantValue); + } + +private: + static arith::CmpIPredicate getPredicate(arith::CmpIOp op) { + return op.getPredicate(); + } + + static bool gtPredicate(arith::CmpIPredicate predicate) { + return predicate == arith::CmpIPredicate::sgt || + predicate == arith::CmpIPredicate::ugt; + } + + static bool gePredicate(arith::CmpIPredicate predicate) { + return predicate == arith::CmpIPredicate::sge || + predicate == arith::CmpIPredicate::uge; + } + + static bool ltPredicate(arith::CmpIPredicate predicate) { + return predicate == arith::CmpIPredicate::slt || + predicate == arith::CmpIPredicate::ult; + } + + static bool lePredicate(arith::CmpIPredicate predicate) { + return predicate == arith::CmpIPredicate::sle || + predicate == arith::CmpIPredicate::ule; + } + + static bool compare(arith::CmpIPredicate predicate, int64_t lhs, + int64_t rhs) { + switch (predicate) { + case arith::CmpIPredicate::eq: + return lhs == rhs; + case arith::CmpIPredicate::ne: + return lhs != rhs; + case arith::CmpIPredicate::slt: + return lhs < rhs; + case arith::CmpIPredicate::sle: + return lhs <= rhs; + case arith::CmpIPredicate::sgt: + return lhs > rhs; + case arith::CmpIPredicate::sge: + return lhs >= rhs; + case arith::CmpIPredicate::ult: + return (uint64_t)lhs < (uint64_t)rhs; + case arith::CmpIPredicate::ule: + return (uint64_t)lhs <= (uint64_t)rhs; + case arith::CmpIPredicate::ugt: + return (uint64_t)lhs > (uint64_t)rhs; + case arith::CmpIPredicate::uge: + return (uint64_t)lhs >= (uint64_t)rhs; + default: + break; + } + llvm_unreachable("unknown comparison predicate"); + } +}; + +template +class SelectOpAxisInfoVisitor final : public AxisInfoVisitorImpl { +public: + using AxisInfoVisitorImpl::AxisInfoVisitorImpl; + + AxisInfo + getAxisInfo(OpTy op, + ArrayRef *> operands) override { + auto condConstancy = operands[0]->getValue().getConstancy(); + auto lhsInfo = operands[1]->getValue(); + auto rhsInfo = operands[2]->getValue(); + auto rank = lhsInfo.getRank(); + + AxisInfo::DimVectorT contiguity, divisibility, constancy; + std::optional constantValue; + if (operands[0]->getValue().getConstantValue().has_value()) { + if (operands[0]->getValue().getConstantValue() == 0) { + 
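+        // The condition folds to the constant "false", so the select always
+        // produces its false (rhs) operand; forward that operand's axis info.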
contiguity = rhsInfo.getContiguity(); + divisibility = rhsInfo.getDivisibility(); + constancy = rhsInfo.getConstancy(); + constantValue = rhsInfo.getConstantValue(); + } else { + contiguity = lhsInfo.getContiguity(); + divisibility = lhsInfo.getDivisibility(); + constancy = lhsInfo.getConstancy(); + constantValue = lhsInfo.getConstantValue(); + } + } else { + // The condition can be either a tensor or i1. + // If i1 is used as the condition, the entire tensor of either + // lhs or rhs is selected. + bool i1Cond = isa(op.getOperand(0).getType()); + for (auto d = 0; d < rank; ++d) { + if (i1Cond) { + constancy.push_back( + std::min(lhsInfo.getConstancy(d), rhsInfo.getConstancy(d))); + divisibility.push_back( + std::min(lhsInfo.getDivisibility(d), rhsInfo.getDivisibility(d))); + contiguity.push_back( + std::min(lhsInfo.getContiguity(d), rhsInfo.getContiguity(d))); + } else { + constancy.push_back( + std::min(gcd(lhsInfo.getConstancy(d), condConstancy[d]), + gcd(rhsInfo.getConstancy(d), condConstancy[d]))); + contiguity.push_back( + std::min(gcd(lhsInfo.getContiguity(d), condConstancy[d]), + gcd(rhsInfo.getContiguity(d), condConstancy[d]))); + if (contiguity.back() == lhsInfo.getContiguity(d) && + contiguity.back() == rhsInfo.getContiguity(d)) { + // Contiguity not changed + divisibility.push_back( + gcd(lhsInfo.getDivisibility(d), rhsInfo.getDivisibility(d))); + } else { + // Contiguity changed, we cannot use only divisibility. + // For example, the following example should have contiguity 2 and + // divisibility 2 + // [[0, 1], [4, 5]] + // [[16, 17, 18, 19]] + divisibility.push_back( + std::min(gcd(lhsInfo.getDivisibility(d), contiguity.back()), + gcd(rhsInfo.getDivisibility(d), contiguity.back()))); + } + } + } + if (lhsInfo.getConstantValue().has_value() && + rhsInfo.getConstantValue().has_value() && + lhsInfo.getConstantValue() == rhsInfo.getConstantValue()) + constantValue = lhsInfo.getConstantValue(); + } + + return AxisInfo(contiguity, divisibility, constancy, constantValue); + } +}; + +template +class LogicalOpAxisInfoVisitor final : public BinaryOpVisitorImpl { +public: + using BinaryOpVisitorImpl::BinaryOpVisitorImpl; + +private: + int64_t getConstancy(OpTy op, const AxisInfo &lhs, const AxisInfo &rhs, + int dim) override { + return gcd(lhs.getConstancy(dim), rhs.getConstancy(dim)); + } + + std::optional getConstantValue(OpTy op, const AxisInfo &lhs, + const AxisInfo &rhs) override { + if (lhs.getConstantValue().has_value() && + rhs.getConstantValue().has_value()) { + if constexpr (std::is_same_v) { + return {lhs.getConstantValue().value() & + rhs.getConstantValue().value()}; + } else if constexpr (std::is_same_v) { + return {lhs.getConstantValue().value() | + rhs.getConstantValue().value()}; + } else if constexpr (std::is_same_v) { + return {lhs.getConstantValue().value() ^ + rhs.getConstantValue().value()}; + } + } + return {}; + } +}; + +class ShLIOpAxisInfoVisitor final : public BinaryOpVisitorImpl { +public: + using BinaryOpVisitorImpl::BinaryOpVisitorImpl; + +private: + int64_t getContiguity(arith::ShLIOp op, const AxisInfo &lhs, + const AxisInfo &rhs, int dim) override { + if (rhs.getConstantValue().has_value() && + rhs.getConstantValue().value() == 0) + return lhs.getContiguity(dim); + else + return 1; + } + + int64_t getDivisibility(arith::ShLIOp op, const AxisInfo &lhs, + const AxisInfo &rhs, int dim) override { + auto shift = rhs.getConstantValue().value_or(0); + auto lhsDivisibility = lhs.getDivisibility(dim); + if (lhs.getContiguity(dim) > 1 && shift) { + // Treat 
[2^n,2^n+1,...]'s divisibility as 1 instead of 2^n + lhsDivisibility = 1; + } + auto numBits = log2Int(lhsDivisibility); + return multiplyDivisor(lhsDivisibility, 1 << shift); + } + + int64_t getConstancy(arith::ShLIOp op, const AxisInfo &lhs, + const AxisInfo &rhs, int dim) override { + return gcd(lhs.getConstancy(dim), rhs.getConstancy(dim)); + } + + std::optional getConstantValue(arith::ShLIOp op, const AxisInfo &lhs, + const AxisInfo &rhs) override { + if (lhs.getConstantValue().has_value() && + rhs.getConstantValue().has_value()) + return {lhs.getConstantValue().value() << rhs.getConstantValue().value()}; + return {}; + } +}; + +template +class ShROpAxisInfoVisitor final : public BinaryOpVisitorImpl { +public: + using BinaryOpVisitorImpl::BinaryOpVisitorImpl; + +private: + int64_t getContiguity(OpTy op, const AxisInfo &lhs, const AxisInfo &rhs, + int dim) override { + if (rhs.getConstantValue().has_value() && + rhs.getConstantValue().value() == 0) + return lhs.getContiguity(dim); + else + return 1; + } + + int64_t getDivisibility(OpTy op, const AxisInfo &lhs, const AxisInfo &rhs, + int dim) override { + if (!rhs.getConstantValue().has_value()) + return 1; + auto shift = rhs.getConstantValue().value(); + auto lhsDivisibility = lhs.getDivisibility(dim); + if (lhs.getContiguity(dim) > 1 && shift) { + // Treat [2^n,2^n+1,...]'s divisibility as 1 instead of 2^n + lhsDivisibility = 1; + } + return std::max(1, lhsDivisibility / (1 << shift)); + } + + int64_t getConstancy(OpTy op, const AxisInfo &lhs, const AxisInfo &rhs, + int dim) override { + return gcd(lhs.getConstancy(dim), rhs.getConstancy(dim)); + } + + std::optional getConstantValue(OpTy op, const AxisInfo &lhs, + const AxisInfo &rhs) override { + if (lhs.getConstantValue().has_value() && + rhs.getConstantValue().has_value()) + return {lhs.getConstantValue().value() >> rhs.getConstantValue().value()}; + return {}; + } +}; + +template +class MaxMinOpAxisInfoVisitor final : public AxisInfoVisitorImpl { +public: + using AxisInfoVisitorImpl::AxisInfoVisitorImpl; + + AxisInfo + getAxisInfo(OpTy op, + ArrayRef *> operands) override { + auto lhsInfo = operands[0]->getValue(); + auto rhsInfo = operands[1]->getValue(); + auto rank = lhsInfo.getRank(); + std::optional constantValue; + if (lhsInfo.getConstantValue().has_value() && + rhsInfo.getConstantValue().has_value()) { + if constexpr (std::is_same_v || + std::is_same_v) { + constantValue = {std::max(lhsInfo.getConstantValue().value(), + rhsInfo.getConstantValue().value())}; + } else if constexpr (std::is_same_v || + std::is_same_v) { + constantValue = {std::min(lhsInfo.getConstantValue().value(), + rhsInfo.getConstantValue().value())}; + } + return AxisInfo(/*knownContiguity=*/AxisInfo::DimVectorT(rank, 1), + /*knownDivisibility=*/AxisInfo::DimVectorT(rank, 1), + /*knownConstancy=*/AxisInfo::DimVectorT(rank, 1), + /*constantValue=*/constantValue); + } else { + AxisInfo::DimVectorT contiguity, divisibility, constancy; + for (auto d = 0; d < rank; ++d) { + constancy.push_back( + std::min(lhsInfo.getConstancy(d), rhsInfo.getConstancy(d))); + divisibility.push_back( + std::min(lhsInfo.getDivisibility(d), rhsInfo.getDivisibility(d))); + contiguity.push_back( + std::min(lhsInfo.getContiguity(d), rhsInfo.getContiguity(d))); + } + return AxisInfo(contiguity, divisibility, constancy, std::nullopt); + } + } +}; + +class MakeTensorPtrOpAxisInfoVisitor final + : public AxisInfoVisitorImpl { +public: + using AxisInfoVisitorImpl::AxisInfoVisitorImpl; + + AxisInfo + getAxisInfo(triton::MakeTensorPtrOp op, 
+ ArrayRef *> operands) override { + LDBG("MakeTensorPtrOpAxisInfoVisitor: " << *op); + + AxisInfo ptrInfo = operands[0]->getValue(); + AxisInfo shapeInfo0 = operands[1]->getValue(); + AxisInfo shapeInfo1 = operands[2]->getValue(); + AxisInfo strideInfo0 = operands[3]->getValue(); + AxisInfo strideInfo1 = operands[4]->getValue(); + + std::optional shape0 = shapeInfo0.getConstantValue(); + std::optional shape1 = shapeInfo1.getConstantValue(); + std::optional stride0 = strideInfo0.getConstantValue(); + std::optional stride1 = strideInfo1.getConstantValue(); + + auto isOne = [](std::optional value) { + return value.has_value() && value.value() == 1; + }; + + AxisInfo::DimVectorT contiguity{ + shape0.has_value() && isOne(stride0) ? shape0.value() : 1, + shape1.has_value() && isOne(stride1) ? shape1.value() : 1}; + + int64_t ptrDivisibility = ptrInfo.getDivisibility()[0]; + int64_t strideDivisibility0 = strideInfo0.getDivisibility()[0]; + int64_t strideDivisibility1 = strideInfo1.getDivisibility()[0]; + + LDBG("ptrDivisibility: " << ptrDivisibility); + LDBG("strideDivisibility0: " << strideDivisibility0); + LDBG("strideDivisibility1: " << strideDivisibility1); + + AxisInfo::DimVectorT divisibility{1, 1}; + if (ptrDivisibility > 1) { + if (contiguity[0] > 1) + divisibility[0] = std::min(ptrDivisibility, strideDivisibility1); + if (contiguity[1] > 1) + divisibility[1] = std::min(ptrDivisibility, strideDivisibility0); + } + + AxisInfo::DimVectorT constancy{1, 1}; + + return AxisInfo(contiguity, divisibility, constancy); + } +}; + +class AdvanceOpAxisInfoVisitor final + : public AxisInfoVisitorImpl { +public: + using AxisInfoVisitorImpl::AxisInfoVisitorImpl; + AxisInfo + getAxisInfo(triton::AdvanceOp op, + ArrayRef *> operands) override { + return operands[0]->getValue(); + } +}; + +//===----------------------------------------------------------------------===// +// AxisInfoAnalysis +//===----------------------------------------------------------------------===// + +AxisInfoAnalysis::AxisInfoAnalysis(DataFlowSolver &solver) + : dataflow::SparseForwardDataFlowAnalysis>( + solver) { + // UnrealizedConversionCast: + // This is needed by TritonGPUToLLVM, to get AxisInfo when the graph is + // in the process of a PartialConversion, where UnrealizedConversionCast + // may exist + visitors.append, + CastOpAxisInfoVisitor, + CastOpAxisInfoVisitor, + CastOpAxisInfoVisitor, + CastOpAxisInfoVisitor, + CastOpAxisInfoVisitor, + CastOpAxisInfoVisitor>(); + // TODO: Remove rules for LLVM::ConstantOp, LLVM::AddOp + // when scf.for supports integer induction variables + visitors.append(); + visitors.append, + ConstantOpAxisInfoVisitor>(); + visitors.append, + AddSubOpAxisInfoVisitor, + AddSubOpAxisInfoVisitor, + AddSubOpAxisInfoVisitor>(); + visitors.append(); + visitors.append, + DivOpAxisInfoVisitor>(); + visitors.append, + RemOpAxisInfoVisitor>(); + visitors.append(); + visitors.append(); + visitors.append(); + visitors.append>(); + visitors.append, + LogicalOpAxisInfoVisitor, + LogicalOpAxisInfoVisitor>(); + visitors.append>(); + visitors.append, + ShROpAxisInfoVisitor>(); + visitors.append, + MaxMinOpAxisInfoVisitor, + MaxMinOpAxisInfoVisitor, + MaxMinOpAxisInfoVisitor>(); + visitors.append(); + visitors.append(); + visitors.append(); +} + +LogicalResult AxisInfoAnalysis::visitOperation( + Operation *op, ArrayRef *> operands, + ArrayRef *> results) { + LDBG("visitOperation: << " << *op); + // TODO: For sure not the right way to do this + // but why is scf.if not initialized otherwise? 
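+  // Operands whose lattice is still uninitialized (rank 0), e.g. values
+  // defined by scf.if regions, are reset to the pessimistic entry state
+  // before the visitors are applied.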
+  for (auto op : operands)
+    if (op->getValue().getRank() == 0)
+      setToEntryState((dataflow::Lattice<AxisInfo> *)op);
+  AxisInfo curr = visitors.apply(op, operands);
+  if (curr.getRank() == 0) {
+    setAllToEntryStates(results);
+    return success();
+  }
+  // override with hint
+  auto newContiguity = curr.getContiguity();
+  auto newDivisibility = curr.getDivisibility();
+  auto newConstancy = curr.getConstancy();
+  if (Attribute attr = op->getDiscardableAttr("tt.contiguity")) {
+    auto vals = cast<DenseElementsAttr>(attr).getValues<int>();
+    newContiguity = AxisInfo::DimVectorT(vals.begin(), vals.end());
+  }
+  if (Attribute attr = op->getDiscardableAttr("tt.divisibility")) {
+    auto vals = cast<DenseElementsAttr>(attr).getValues<int>();
+    newDivisibility = AxisInfo::DimVectorT(vals.begin(), vals.end());
+  }
+  if (Attribute attr = op->getDiscardableAttr("tt.constancy")) {
+    auto vals = cast<DenseElementsAttr>(attr).getValues<int>();
+    newConstancy = AxisInfo::DimVectorT(vals.begin(), vals.end());
+  }
+  curr = AxisInfo(newContiguity, newDivisibility, newConstancy,
+                  curr.getConstantValue());
+  // join all lattice elements
+  for (auto *result : results)
+    propagateIfChanged(result, result->join(curr));
+  return success();
+}
+
+void AxisInfoAnalysis::visitForOpInductionVar(
+    scf::ForOp op, ArrayRef<dataflow::Lattice<AxisInfo> *> argLattices) {
+  auto lb = getLatticeElementFor(op, op.getLowerBound())->getValue();
+  auto step = getLatticeElementFor(op, op.getStep())->getValue();
+
+  AxisInfo::DimVectorT knownContiguity(1, 1);
+  AxisInfo::DimVectorT knownDivisibility(1, 1);
+  AxisInfo::DimVectorT knownConstancy(1, 1);
+  knownDivisibility[0] = gcd(lb.getDivisibility(0), step.getDivisibility(0));
+  auto inductionVar =
+      AxisInfo(knownContiguity, knownDivisibility, knownConstancy);
+  (void)argLattices[0]->join(inductionVar);
+}
+
+} // anonymous namespace
+
+template <class T>
+void AxisInfo::initPessimisticStateFromFunc(int argNumber, T funcOp,
+                                            DimVectorT *contiguity,
+                                            DimVectorT *divisibility,
+                                            DimVectorT *constancy) {
+  // list of attributes that we care about
+  SmallVector<std::pair<DimVectorT *, std::string>> retVecs;
+  retVecs.push_back({contiguity, "tt.contiguity"});
+  retVecs.push_back({divisibility, "tt.divisibility"});
+  retVecs.push_back({constancy, "tt.constancy"});
+  // initialize attributes one by one
+  for (auto [vec, attrName] : retVecs) {
+    Attribute attr = funcOp.getArgAttr(argNumber, attrName);
+    if (auto int_attr = dyn_cast_or_null<IntegerAttr>(attr))
+      *vec = DimVectorT(contiguity->size(), int_attr.getValue().getZExtValue());
+    if (auto dense_attr = dyn_cast_or_null<DenseElementsAttr>(attr)) {
+      auto vals = dense_attr.getValues<int>();
+      *vec = DimVectorT(vals.begin(), vals.end());
+    }
+  }
+}
+
+/*static*/ AxisInfo AxisInfo::getPessimisticValueState(Value value) {
+  auto rank = 1;
+  if (TensorType ty = dyn_cast<TensorType>(value.getType()))
+    rank = ty.getRank();
+  if (triton::PointerType ty = dyn_cast<triton::PointerType>(value.getType()))
+    if (TensorType elemTy = dyn_cast<TensorType>(ty.getPointeeType()))
+      rank = elemTy.getRank();
+
+  DimVectorT knownContiguity(rank, 1);
+  DimVectorT knownDivisibility(rank, 1);
+  DimVectorT knownConstancy(rank, 1);
+
+  BlockArgument blockArg = dyn_cast<BlockArgument>(value);
+
+  if (blockArg && blockArg.getOwner()->isEntryBlock()) {
+    Operation *op = blockArg.getOwner()->getParentOp();
+    if (auto fun = dyn_cast<FunctionOpInterface>(op))
+      initPessimisticStateFromFunc(blockArg.getArgNumber(), fun,
+                                   &knownContiguity, &knownDivisibility,
+                                   &knownConstancy);
+    // LLVM codegen checks alignment to generate vector loads/stores;
+    // it would be nice if this wasn't the case.
+    else if (auto fun = dyn_cast<LLVM::LLVMFuncOp>(op))
+      initPessimisticStateFromFunc(blockArg.getArgNumber(), fun,
+                                   &knownContiguity, &knownDivisibility,
&knownConstancy); + else if (isa(op)) { + // scf::ForOp, scf::IfOp, scf::WhileOp + // Control flow operations are initialized with "unknown" state: + // the maximum possible divisibility, contiguity, and constancy. + knownDivisibility = DimVectorT(rank, highestPowOf2Divisor(0)); + knownConstancy = DimVectorT(rank, highestPowOf2Divisor(0)); + knownContiguity = DimVectorT(rank, highestPowOf2Divisor(0)); + } + } else if (Operation *op = value.getDefiningOp()) { + if (isa(op)) { + // scf::ForOp, scf::IfOp, scf::WhileOp + // Control flow operations are initialized with "unknown" state: + // the maximum possible divisibility, contiguity, and constancy. + knownDivisibility = DimVectorT(rank, highestPowOf2Divisor(0)); + knownConstancy = DimVectorT(rank, highestPowOf2Divisor(0)); + knownContiguity = DimVectorT(rank, highestPowOf2Divisor(0)); + } + // Other operations are conservatively initialized with the lowest possible + // divisibility, contiguity, and constancy unless they have specified. + if (Attribute attr = op->getDiscardableAttr("tt.divisibility")) { + auto vals = cast(attr).getValues(); + knownDivisibility = DimVectorT(vals.begin(), vals.end()); + } + if (Attribute attr = op->getDiscardableAttr("tt.contiguity")) { + auto vals = cast(attr).getValues(); + knownContiguity = DimVectorT(vals.begin(), vals.end()); + } + if (Attribute attr = op->getDiscardableAttr("tt.constancy")) { + auto vals = cast(attr).getValues(); + knownConstancy = DimVectorT(vals.begin(), vals.end()); + } + } + + return AxisInfo(knownContiguity, knownDivisibility, knownConstancy); +} + +/*static*/ AxisInfo AxisInfo::join(const AxisInfo &lhs, const AxisInfo &rhs) { + // If one argument is not initialized, return the other. + if (lhs.getRank() == 0) + return rhs; + if (rhs.getRank() == 0) + return lhs; + DimVectorT contiguity; + DimVectorT divisibility; + DimVectorT constancy; + for (auto d = 0; d < lhs.getRank(); ++d) { + contiguity.push_back(gcd(lhs.getContiguity(d), rhs.getContiguity(d))); + divisibility.push_back(gcd(lhs.getDivisibility(d), rhs.getDivisibility(d))); + constancy.push_back(gcd(lhs.getConstancy(d), rhs.getConstancy(d))); + } + std::optional constantValue; + if (lhs.getConstantValue().has_value() && + rhs.getConstantValue().has_value() && + lhs.getConstantValue() == rhs.getConstantValue()) + constantValue = lhs.getConstantValue(); + return AxisInfo(contiguity, divisibility, constancy, constantValue); +} + +unsigned ModuleAxisInfoAnalysis::getPtrContiguity(Value ptr) { + Type ptrTy = ptr.getType(); + auto tensorTy = + isTensorPointerType(ptrTy) + ? cast(cast(ptrTy).getPointeeType()) + : dyn_cast(ptrTy); + if (!tensorTy) + return 1; + auto layout = tensorTy.getEncoding(); + + // Here order should be ordered by contiguous first, so the first element + // should have the largest contiguous. + auto order = triton::gpu::getOrder(layout); + unsigned align = getPtrAlignment(ptr); + + auto uniqueContigPerThread = + triton::gpu::getUniqueContigPerThread(layout, tensorTy.getShape()); + assert(order[0] < uniqueContigPerThread.size() && + "Unexpected uniqueContigPerThread size"); + unsigned contiguity = uniqueContigPerThread[order[0]]; + LDBG("getPtrContiguity uniqueContigPerThread = " << contiguity); + contiguity = std::min(align, contiguity); + + return contiguity; +} + +unsigned ModuleAxisInfoAnalysis::getPtrAlignment(Value ptr) { + Type ptrTy = ptr.getType(); + auto tensorTy = + isTensorPointerType(ptrTy) + ? 
cast(cast(ptrTy).getPointeeType()) + : dyn_cast(ptrTy); + if (!tensorTy) + return 1; + auto *axisInfo = getAxisInfo(ptr); + if (!axisInfo) + return 1; + auto layout = tensorTy.getEncoding(); + auto order = triton::gpu::getOrder(layout); + auto maxMultipleBytes = axisInfo->getDivisibility(order[0]); + auto maxContig = axisInfo->getContiguity(order[0]); + auto elemNumBits = triton::getPointeeBitWidth(tensorTy); + auto elemNumBytes = std::max(elemNumBits / 8, 1); + auto maxMultiple = std::max(maxMultipleBytes / elemNumBytes, 1); + unsigned alignment = std::min(maxMultiple, maxContig); + LDBG("getPtrAlignment order[0] " + << order[0] << " maxMultipleBytes = " << maxMultipleBytes + << " maxContig = " << maxContig << " elemNumBits = " << elemNumBits + << " maxMultiple = " << maxMultiple << " alignment " << alignment); + LLVM_DEBUG({ + std::string axisStr; + llvm::raw_string_ostream os(axisStr); + axisInfo->print(os); + LDBG("-- " << axisStr); + }); + return alignment; +} + +unsigned ModuleAxisInfoAnalysis::getMaskAlignment(Value mask) { + Type ptrTy = mask.getType(); + auto tensorTy = + isTensorPointerType(ptrTy) + ? cast(cast(ptrTy).getPointeeType()) + : dyn_cast(ptrTy); + if (!tensorTy) + return 1; + auto *axisInfo = getAxisInfo(mask); + if (!axisInfo) + return 1; + auto maskOrder = triton::gpu::getOrder(tensorTy.getEncoding()); + auto alignment = std::max(axisInfo->getConstancy(maskOrder[0]), 1); + LDBG("getMaskAlignment maskOrder[0] " << maskOrder[0] << " alignment " + << alignment); + LLVM_DEBUG({ + std::string axisStr; + llvm::raw_string_ostream os(axisStr); + axisInfo->print(os); + LDBG("-- " << axisStr); + }); + return alignment; +} + +void ModuleAxisInfoAnalysis::initialize(FunctionOpInterface funcOp) { + std::unique_ptr solver = createDataFlowSolver(); + AxisInfoAnalysis *analysis = solver->load(); + if (failed(solver->initializeAndRun(funcOp))) + return; + auto *axisInfoMap = getFuncData(funcOp); + auto updateAxisInfoMap = [&](Value value) { + auto axisInfo = analysis->getLatticeElement(value)->getValue(); + AxisInfo curAxisInfo; + if (axisInfoMap->count(value)) { + curAxisInfo = AxisInfo::join(axisInfo, axisInfoMap->lookup(value)); + } else { + curAxisInfo = axisInfo; + } + (*axisInfoMap)[value] = curAxisInfo; + }; + funcOp.walk([&](Operation *op) { + for (auto value : op->getResults()) { + updateAxisInfoMap(value); + } + }); + funcOp.walk([&](Block *block) { + for (auto value : block->getArguments()) { + updateAxisInfoMap(value); + } + }); +} + +void ModuleAxisInfoAnalysis::update(CallOpInterface callOp, + FunctionOpInterface callee) { + auto caller = callOp->getParentOfType(); + auto *axisInfoMap = getFuncData(caller); + for (auto entry : llvm::enumerate(callOp->getOperands())) { + auto index = entry.index(); + auto value = entry.value(); + auto setAttrFn = [&](StringRef attrName, int64_t prevValue) { + auto curValue = highestPowOf2Divisor(0); + if (callee.getArgAttrOfType(index, attrName)) { + curValue = + callee.getArgAttrOfType(index, attrName).getInt(); + } + auto attr = IntegerAttr::get(IntegerType::get(callee.getContext(), 64), + gcd(prevValue, curValue)); + callee.setArgAttr(index, attrName, attr); + }; + auto axisInfo = axisInfoMap->lookup(value); + assert(axisInfo.getRank() == 1 && "only scalar arguments are supported"); + setAttrFn("tt.contiguity", axisInfo.getContiguity(0)); + setAttrFn("tt.divisibility", axisInfo.getDivisibility(0)); + setAttrFn("tt.constancy", axisInfo.getConstancy(0)); + } +} + +} // namespace mlir::triton::intel diff --git 
a/third_party/intel/lib/Analysis/CMakeLists.txt b/third_party/intel/lib/Analysis/CMakeLists.txt index baf1c98656..e51b359137 100644 --- a/third_party/intel/lib/Analysis/CMakeLists.txt +++ b/third_party/intel/lib/Analysis/CMakeLists.txt @@ -1,4 +1,5 @@ add_triton_library(TritonIntelAnalysis + AxisInfo.cpp DPAS.cpp Liveness.cpp Utility.cpp From 4dc1cf189c756d5498651214a2ad0c8246fa767c Mon Sep 17 00:00:00 2001 From: "Tiotto, Ettore" Date: Wed, 16 Oct 2024 20:52:03 +0000 Subject: [PATCH 02/32] WIP: Coalescing for block ptrs Signed-off-by: Tiotto, Ettore --- test/TritonIntelGPU/coalesce.mlir | 274 +++++++++++++ test/lib/Analysis/CMakeLists.txt | 1 - third_party/intel/backend/compiler.py | 4 +- .../include/Dialect/TritonIntelGPU/IR/Utils.h | 31 +- .../TritonIntelGPU/Transforms/Passes.td | 17 + .../TritonIntelGPUTransforms/CMakeLists.txt | 1 + .../lib/TritonIntelGPUTransforms/Coalesce.cpp | 385 ++++++++++++++++++ third_party/intel/triton_xpu.cc | 8 +- 8 files changed, 712 insertions(+), 9 deletions(-) create mode 100644 test/TritonIntelGPU/coalesce.mlir create mode 100644 third_party/intel/lib/TritonIntelGPUTransforms/Coalesce.cpp diff --git a/test/TritonIntelGPU/coalesce.mlir b/test/TritonIntelGPU/coalesce.mlir new file mode 100644 index 0000000000..1429dc2065 --- /dev/null +++ b/test/TritonIntelGPU/coalesce.mlir @@ -0,0 +1,274 @@ +// RUN: triton-opt %s -split-input-file -tritonintelgpu-coalesce | FileCheck %s + +#blocked0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> +#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [0, 1]}> +#blocked2 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [1, 4], order = [0, 1]}> +#slice1dim1 = #triton_gpu.slice<{dim = 1, parent = #blocked1}> +#slice2dim0 = #triton_gpu.slice<{dim = 0, parent = #blocked2}> + +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} { + +// CHECK: [[row_layout:#.*]] = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [2, 16], warpsPerCTA = [4, 1], order = [1, 0]}> +// CHECK: [[col_layout:#.*]] = #triton_gpu.blocked<{sizePerThread = [4, 1], threadsPerWarp = [16, 2], warpsPerCTA = [1, 4], order = [0, 1]}> +// CHECK: [[load_ptr:%.*]] = triton_gpu.convert_layout {{.*}} -> tensor<64x64x!tt.ptr, [[row_layout]]> +// CHECK: [[load_mask:%.*]] = triton_gpu.convert_layout {{.*}} -> tensor<64x64xi1, [[row_layout]]> +// CHECK: [[load_other:%.*]] = triton_gpu.convert_layout {{.*}} -> tensor<64x64xf32, [[row_layout]]> +// CHECK: [[load_val:%.*]] = tt.load [[load_ptr]], [[load_mask]], [[load_other]] : tensor<64x64x!tt.ptr, [[row_layout]]> +// CHECK: [[store_ptr:%.*]] = triton_gpu.convert_layout {{.*}} -> tensor<64x64x!tt.ptr, [[col_layout]]> +// CHECK: [[store_val:%.*]] = triton_gpu.convert_layout {{.*}} -> tensor<64x64xf32, [[col_layout]]> +// CHECK: [[store_mask:%.*]] = triton_gpu.convert_layout {{.*}} -> tensor<64x64xi1, [[col_layout]]> +// CHECK: tt.store [[store_ptr]], [[store_val]], [[store_mask]] +tt.func @transpose(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, + %arg1: i32 {tt.divisibility = 16 : i32}, + %arg2: !tt.ptr {tt.divisibility = 16 : i32}, + %arg3: i32 {tt.divisibility = 16 : i32}) { + %cst = arith.constant dense : tensor<64x64xi1, #blocked1> + %cst_0 = arith.constant dense<0.000000e+00> : tensor<64x64xf32, #blocked1> + %00 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #slice1dim1> + %01 = tt.make_range {end = 64 : i32, 
start = 0 : i32} : tensor<64xi32, #slice2dim0> + %1 = tt.expand_dims %00 {axis = 1 : i32} : tensor<64xi32, #slice1dim1> -> tensor<64x1xi32, #blocked1> + %2 = tt.splat %arg1 : i32 -> tensor<64x1xi32, #blocked1> + %3 = arith.muli %1, %2 : tensor<64x1xi32, #blocked1> + %4 = tt.splat %arg0 : !tt.ptr -> tensor<64x1x!tt.ptr, #blocked1> + %5 = tt.addptr %4, %3 : tensor<64x1x!tt.ptr, #blocked1>, tensor<64x1xi32, #blocked1> + %6 = tt.expand_dims %01 {axis = 0 : i32} : tensor<64xi32, #slice2dim0> -> tensor<1x64xi32, #blocked2> + %7 = tt.broadcast %5 : tensor<64x1x!tt.ptr, #blocked1> -> tensor<64x64x!tt.ptr, #blocked1> + %8 = tt.broadcast %6 : tensor<1x64xi32, #blocked2> -> tensor<64x64xi32, #blocked2> + %9 = triton_gpu.convert_layout %8 : tensor<64x64xi32, #blocked2> -> tensor<64x64xi32, #blocked1> + %10 = tt.addptr %7, %9 : tensor<64x64x!tt.ptr, #blocked1>, tensor<64x64xi32, #blocked1> + %11 = tt.splat %arg2 : !tt.ptr -> tensor<64x1x!tt.ptr, #blocked1> + %12 = tt.addptr %11, %1 : tensor<64x1x!tt.ptr, #blocked1>, tensor<64x1xi32, #blocked1> + %13 = tt.splat %arg3 : i32 -> tensor<1x64xi32, #blocked2> + %14 = arith.muli %6, %13 : tensor<1x64xi32, #blocked2> + %15 = tt.broadcast %12 : tensor<64x1x!tt.ptr, #blocked1> -> tensor<64x64x!tt.ptr, #blocked1> + %16 = tt.broadcast %14 : tensor<1x64xi32, #blocked2> -> tensor<64x64xi32, #blocked2> + %17 = triton_gpu.convert_layout %16 : tensor<64x64xi32, #blocked2> -> tensor<64x64xi32, #blocked1> + %18 = tt.addptr %15, %17 : tensor<64x64x!tt.ptr, #blocked1>, tensor<64x64xi32, #blocked1> + %19 = tt.load %10, %cst, %cst_0 : tensor<64x64x!tt.ptr, #blocked1> + tt.store %18, %19, %cst : tensor<64x64x!tt.ptr, #blocked1> + tt.return +} + +} + +// ----- + +#blocked = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 32 : i32} { + + +// CHECK: [[NARROW_LAYOUT:#.*]] = #triton_gpu.blocked<{sizePerThread = [8], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> +// CHECK: [[WIDE_LAYOUT:#.*]] = #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> +tt.func public @load_tensors_two_types(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: !tt.ptr {tt.divisibility = 16 : i32}, %arg3: i32) attributes {noinline = false} { + %c1024_i32 = arith.constant 1024 : i32 + %0 = tt.get_program_id x : i32 + %1 = arith.muli %0, %c1024_i32 : i32 + %2 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32, #blocked> + %3 = tt.splat %1 : i32 -> tensor<1024xi32, #blocked> + %4 = arith.addi %3, %2 : tensor<1024xi32, #blocked> + %5 = tt.splat %arg3 : i32 -> tensor<1024xi32, #blocked> + %6 = arith.cmpi "slt", %4, %5 : tensor<1024xi32, #blocked> + %7 = tt.splat %arg0 : !tt.ptr -> tensor<1024x!tt.ptr, #blocked> + %8 = tt.addptr %7, %4 : tensor<1024x!tt.ptr, #blocked>, tensor<1024xi32, #blocked> + %9 = tt.load %8, %6 : tensor<1024x!tt.ptr, #blocked> + %10 = tt.splat %arg1 : !tt.ptr -> tensor<1024x!tt.ptr, #blocked> + %11 = tt.addptr %10, %4 : tensor<1024x!tt.ptr, #blocked>, tensor<1024xi32, #blocked> + %12 = tt.load %11, %6 : tensor<1024x!tt.ptr, #blocked> + %13 = arith.extf %12 : tensor<1024xf16, #blocked> to tensor<1024xf32, #blocked> + %14 = arith.addf %9, %13 : tensor<1024xf32, #blocked> + %15 = tt.splat %arg2 : !tt.ptr -> tensor<1024x!tt.ptr, #blocked> + %16 = tt.addptr %15, %4 : tensor<1024x!tt.ptr, #blocked>, 
tensor<1024xi32, #blocked> + // CHECK: tt.store {{.*}} : tensor<1024x!tt.ptr, [[WIDE_LAYOUT]]> + tt.store %16, %14, %6 : tensor<1024x!tt.ptr, #blocked> + tt.return +} + +} + +// ----- + +#blocked = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 32 : i32} { + +// CHECK-NOT: sizePerThread = [4] +// CHECK: #triton_gpu.blocked<{sizePerThread = [8], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> +// CHECK-NOT: sizePerThread = [4] +tt.func public @load_tensors_two_types(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: !tt.ptr {tt.divisibility = 16 : i32}, %arg3: i32) attributes {noinline = false} { + %c1024_i32 = arith.constant 1024 : i32 + %0 = tt.get_program_id x : i32 + %1 = arith.muli %0, %c1024_i32 : i32 + %2 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32, #blocked> + %3 = tt.splat %1 : i32 -> tensor<1024xi32, #blocked> + %4 = arith.addi %3, %2 : tensor<1024xi32, #blocked> + %5 = tt.splat %arg3 : i32 -> tensor<1024xi32, #blocked> + %6 = arith.cmpi "slt", %4, %5 : tensor<1024xi32, #blocked> + %7 = tt.splat %arg0 : !tt.ptr -> tensor<1024x!tt.ptr, #blocked> + %8 = tt.addptr %7, %4 : tensor<1024x!tt.ptr, #blocked>, tensor<1024xi32, #blocked> + %9 = tt.load %8, %6 : tensor<1024x!tt.ptr, #blocked> + %10 = tt.splat %arg1 : !tt.ptr -> tensor<1024x!tt.ptr, #blocked> + %11 = tt.addptr %10, %4 : tensor<1024x!tt.ptr, #blocked>, tensor<1024xi32, #blocked> + %12 = tt.load %11, %6 : tensor<1024x!tt.ptr, #blocked> + %13 = arith.extf %12 : tensor<1024xf16, #blocked> to tensor<1024xf32, #blocked> + %14 = arith.addf %9, %13 : tensor<1024xf32, #blocked> + %15 = tt.splat %arg2 : !tt.ptr -> tensor<1024x!tt.ptr, #blocked> + %16 = tt.addptr %15, %4 : tensor<1024x!tt.ptr, #blocked>, tensor<1024xi32, #blocked> + %17 = arith.truncf %14 : tensor<1024xf32, #blocked> to tensor<1024xf16, #blocked> + tt.store %16, %17, %6 : tensor<1024x!tt.ptr, #blocked> + tt.return +} + +} + +// ----- + +// COM: Reproducer for issue #3866 +// CHECK-LABEL: @test_3866 +// CHECK: tt.load {{.*}} : !tt.ptr +module attributes {"triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 16 : i32} { + tt.func public @test_3866(%arg0: !tt.ptr, %arg1: i32, %arg2: i64) { + %0 = tt.make_tensor_ptr %arg0, [%arg2, %arg2], [%arg2, %arg2], [%arg1, %arg1] {order = array} : > + %1 = tt.load %0 : !tt.ptr> + tt.return + } +} + +// ----- + +#blocked = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [16], warpsPerCTA = [4], order = [0]}> +#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 16], warpsPerCTA = [1, 4], order = [1, 0]}> +#blocked2 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 16], warpsPerCTA = [4, 1], order = [1, 0]}> +#dpas = #triton_intel_gpu.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 2, threadsPerWarp = 16, warpsPerCTA = [1, 4], repCluster = [1, 1], A = [8, 16], B = [16, 16], C = [8, 16]}> +#dot1 = #triton_gpu.dot_op<{opIdx = 0, parent = #dpas, kWidth = 2}> +#dot2 = #triton_gpu.dot_op<{opIdx = 1, parent = #dpas, kWidth = 2}> + +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 16 : i32} { + // CHECK: [[BLOCKED_LAYOUT1:#.*]] = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [2, 8], warpsPerCTA = [4, 1], order = [1, 0]}> + 
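+// COM: The two layouts captured above are what the coalesce pass is expected
+// COM: to pick for the block-pointer load and the loop-carried load below,
+// COM: based on the axis-info contiguity of each tensor pointer.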
// CHECK: [[BLOCKED_LAYOUT2:#.*]] = #triton_gpu.blocked<{sizePerThread = [16, 1], threadsPerWarp = [4, 4], warpsPerCTA = [1, 4], order = [0, 1]}> + // CHECK: @test_block_ptrs + tt.func public @test_block_ptrs(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: !tt.ptr {tt.divisibility = 16 : i32}, %arg3: f32, %arg4: !tt.ptr {tt.divisibility = 16 : i32}, %arg5: !tt.ptr {tt.divisibility = 16 : i32}, %arg6: i32 {tt.divisibility = 16 : i32}, %arg7: i32 {tt.divisibility = 16 : i32}, %arg8: i32 {tt.divisibility = 16 : i32}, %arg9: i32, %arg10: i32, %arg11: i32 {tt.divisibility = 16 : i32}, %arg12: i32, %arg13: i32, %arg14: i32, %arg15: i32, %arg16: i32, %arg17: i32 {tt.divisibility = 16 : i32}, %arg18: i32, %arg19: i32, %arg20: i32) { + %cst = arith.constant dense<0.000000e+00> : tensor<8x16xf32, #dpas> + %cst_0 = arith.constant dense<0.000000e+00> : tensor<8xf32, #blocked> + %cst_1 = arith.constant dense<0xFF800000> : tensor<8xf32, #blocked> + %c1_i32 = arith.constant 1 : i32 + %c16_i32 = arith.constant 16 : i32 + %cst_2 = arith.constant dense<0.000000e+00> : tensor<8x64xf32, #blocked1> + %c0_i32 = arith.constant 0 : i32 + %c1_i64 = arith.constant 1 : i64 + %c64_i64 = arith.constant 64 : i64 + %c8_i32 = arith.constant 8 : i32 + %0 = tt.get_program_id x : i32 + %1 = tt.get_program_id y : i32 + %2 = arith.divsi %1, %arg19 : i32 + %3 = arith.remsi %1, %arg19 : i32 + %4 = arith.extsi %2 : i32 to i64 + %5 = arith.extsi %arg6 : i32 to i64 + %6 = arith.muli %4, %5 : i64 + %7 = arith.extsi %3 : i32 to i64 + %8 = arith.extsi %arg7 : i32 to i64 + %9 = arith.muli %7, %8 : i64 + %10 = arith.addi %6, %9 : i64 + %11 = tt.addptr %arg0, %10 : !tt.ptr, i64 + %12 = arith.muli %0, %c8_i32 : i32 + %13 = arith.extsi %arg20 : i32 to i64 + %14 = arith.extsi %arg8 : i32 to i64 + // CHECK: [[PTR1:%.*]] = tt.make_tensor_ptr {{.*}} : + %15 = tt.make_tensor_ptr %11, [%13, %c64_i64], [%14, %c1_i64], [%12, %c0_i32] {order = array} : > + %16 = tt.addptr %arg1, %10 : !tt.ptr, i64 + %17 = arith.extsi %arg11 : i32 to i64 + // CHECK: [[PTR2:%.*]] = tt.make_tensor_ptr {{.*}} : + %18 = tt.make_tensor_ptr %16, [%c64_i64, %13], [%c1_i64, %17], [%c0_i32, %c0_i32] {order = array} : > + %19 = tt.addptr %arg5, %10 : !tt.ptr, i64 + %20 = arith.extsi %arg17 : i32 to i64 + // CHECK: [[PTR3:%.*]] = tt.make_tensor_ptr {{.*}} : + %21 = tt.make_tensor_ptr %19, [%13, %c64_i64], [%20, %c1_i64], [%12, %c0_i32] {order = array} : > + %22 = tt.make_range {end = 8 : i32, start = 0 : i32} : tensor<8xi32, #blocked> + %23 = tt.splat %12 : i32 -> tensor<8xi32, #blocked> + %24 = arith.addi %23, %22 : tensor<8xi32, #blocked> + // CHECK: [[LOAD1:%.*]] = tt.load [[PTR1]] : !tt.ptr + // CHECK-NEXT: triton_gpu.convert_layout [[LOAD1]] : tensor<8x64xf8E5M2, [[BLOCKED_LAYOUT1]]> -> tensor<8x64xf8E5M2, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> + %25 = tt.load %15 : !tt.ptr> + %26 = arith.addi %0, %c1_i32 : i32 + %27 = arith.muli %26, %c8_i32 : i32 + // CHECK: [[ADVANCE1:%.*]] = tt.advance [[PTR2]], {{.*}} : > + %28 = tt.advance %18, [%c0_i32, %12] : > + // CHECK: [[RES:%.*:2]] = scf.for {{.*}} iter_args(%arg22 = %cst_1, %arg23 = [[ADVANCE1]]) -> (tensor<8xf32, #blocked>, !tt.ptr>) + %29:2 = scf.for %arg21 = %12 to %27 step %c16_i32 iter_args(%arg22 = %cst_1, %arg23 = %28) -> (tensor<8xf32, #blocked>, !tt.ptr>) : i32 { + // CHECK: [[LOAD2:%.*]] = tt.load %arg23 : !tt.ptr> + // CHECK-NEXT: triton_gpu.convert_layout [[LOAD2]] : tensor<64x16xf8E5M2, [[BLOCKED_LAYOUT2]]> -> tensor<64x16xf8E5M2, 
#triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> + %36 = tt.load %arg23 : !tt.ptr> + %37 = tt.fp_to_fp %25 : tensor<8x64xf8E5M2, #dot1> -> tensor<8x64xf16, #dot1> + %38 = tt.fp_to_fp %36 : tensor<64x16xf8E5M2, #dot2> -> tensor<64x16xf16, #dot2> + %39 = tt.dot %37, %38, %cst, inputPrecision = tf32 : tensor<8x64xf16, #dot1> * tensor<64x16xf16, #dot2> -> tensor<8x16xf32, #dpas> + %40 = triton_gpu.convert_layout %39 : tensor<8x16xf32, #dpas> -> tensor<8x16xf32, #blocked2> + %41 = "tt.reduce"(%40) <{axis = 1 : i32}> ({ + ^bb0(%arg24: f32, %arg25: f32): + %44 = arith.maxnumf %arg24, %arg25 : f32 + tt.reduce.return %44 : f32 + }) : (tensor<8x16xf32, #blocked2>) -> tensor<8xf32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>> + %42 = triton_gpu.convert_layout %41 : tensor<8xf32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>> -> tensor<8xf32, #blocked> + // CHECK: [[ADVANCE2:%.*]] = tt.advance %arg23, {{.*}} : > + // CHECK-NEXT: scf.yield {{.*}}, [[ADVANCE2]] : tensor<8xf32, #blocked>, !tt.ptr> + %43 = tt.advance %arg23, [%c0_i32, %c16_i32] : > + scf.yield %42, %43 : tensor<8xf32, #blocked>, !tt.ptr> + } {tt.divisibility_arg1 = dense<16> : tensor<1xi32>} + %30 = arith.addf %29#0, %cst_0 : tensor<8xf32, #blocked> + %31 = arith.muli %1, %arg20 : i32 + %32 = tt.addptr %arg4, %31 : !tt.ptr, i32 + %33 = tt.splat %32 : !tt.ptr -> tensor<8x!tt.ptr, #blocked> + %34 = tt.addptr %33, %24 : tensor<8x!tt.ptr, #blocked>, tensor<8xi32, #blocked> + tt.store %34, %30 : tensor<8x!tt.ptr, #blocked> + %35 = tt.fp_to_fp %cst_2, rounding = rtne : tensor<8x64xf32, #blocked1> -> tensor<8x64xf8E5M2, #blocked1> + tt.store %21, %35 : !tt.ptr> + tt.return + } +} + +// ----- + +#blocked = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [16], warpsPerCTA = [4], order = [0]}> +#dpas = #triton_intel_gpu.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 2, threadsPerWarp = 16, warpsPerCTA = [4, 1], repCluster = [2, 2], A = [16, 16], B = [16, 32], C = [16, 32]}> +#dot2 = #triton_gpu.dot_op<{opIdx = 1, parent = #dpas, kWidth = 2}> + +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} { + // CHECK: [[BLOCKED_LAYOUT:#.*]] = #triton_gpu.blocked<{sizePerThread = [16, 1], threadsPerWarp = [4, 8], warpsPerCTA = [1, 4], order = [0, 1]}> + // CHECK: @test_block_ptrs + tt.func public @test_block_ptrs(%arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg6: i32 {tt.divisibility = 16 : i32}, %arg7: i32 {tt.divisibility = 16 : i32}, %arg11: i32 {tt.divisibility = 16 : i32}, %arg19: i32) { + %cst_1 = arith.constant dense<1.000000e+00> : tensor<64xf32, #triton_gpu.slice<{dim = 1, parent = #dpas}>> + %c32_i32 = arith.constant 32 : i32 + %c64_i32 = arith.constant 64 : i32 + %c64_i64 = arith.constant 64 : i64 + %c1_i64 = arith.constant 1 : i64 + %c0_i32 = arith.constant 0 : i32 + %0 = tt.get_program_id x : i32 + %1 = tt.get_program_id y : i32 + %2 = arith.divsi %1, %arg19 : i32 + %3 = arith.remsi %1, %arg19 : i32 + %4 = arith.extsi %2 : i32 to i64 + %5 = arith.extsi %arg6 : i32 to i64 + %6 = arith.muli %4, %5 : i64 + %7 = arith.extsi %3 : i32 to i64 + %8 = arith.extsi %arg7 : i32 to i64 + %9 = arith.muli %7, %8 : i64 + %10 = arith.addi %6, %9 : i64 + %12 = arith.muli %0, %c64_i32 : i32 + %13 = arith.extsi %arg19 : i32 to i64 + %19 = tt.addptr %arg1, %10 : !tt.ptr, i64 + %20 = arith.extsi %arg11 : i32 to i64 + // CHECK: [[PTR1:%.*]] = tt.make_tensor_ptr {{.*}} : + %21 = tt.make_tensor_ptr %19, [%c64_i64, %13], [%c1_i64, %20], [%c0_i32, %c0_i32] {order = array} : > + // 
CHECK: [[RES:%.*]]:2 = scf.for {{.*}} iter_args(%arg6 = %cst, %arg7 = [[PTR1]]) -> (tensor<64xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>>, !tt.ptr>) + %33:2 = scf.for %arg21 = %c0_i32 to %12 step %c32_i32 iter_args(%arg22 = %cst_1, %arg23 = %21) -> (tensor<64xf32, #triton_gpu.slice<{dim = 1, parent = #dpas}>>, !tt.ptr>) : i32 { + // CHECK: [[LOAD:%.*]] = tt.load %arg7 : !tt.ptr> + // CHECK-NEXT: triton_gpu.convert_layout [[LOAD]] : tensor<64x32xf8E5M2, [[BLOCKED_LAYOUT]]> -> tensor<64x32xf8E5M2, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> + // CHECK-NEXT: scf.yield %arg6, %arg7 : tensor<64xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>>, !tt.ptr> + %load = tt.load %arg23 : !tt.ptr> + scf.yield %arg22, %arg23 : tensor<64xf32, #triton_gpu.slice<{dim = 1, parent = #dpas}>>, !tt.ptr> + } + // CHECK: scf.for {{.*}} iter_args(%arg6 = [[RES]]#0, %arg7 = [[RES]]#1) -> (tensor<64xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>>, !tt.ptr>) + %34:2 = scf.for %arg21 = %c0_i32 to %12 step %c32_i32 iter_args(%arg22 = %33#0, %arg23 = %33#1) -> (tensor<64xf32, #triton_gpu.slice<{dim = 1, parent = #dpas}>>, !tt.ptr>) : i32 { + // CHECK: scf.yield %arg6, %arg7 : tensor<64xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>>, !tt.ptr> + scf.yield %arg22, %arg23 : tensor<64xf32, #triton_gpu.slice<{dim = 1, parent = #dpas}>>, !tt.ptr> + } + tt.return + } +} diff --git a/test/lib/Analysis/CMakeLists.txt b/test/lib/Analysis/CMakeLists.txt index 75f785ce24..da7bc3f78a 100644 --- a/test/lib/Analysis/CMakeLists.txt +++ b/test/lib/Analysis/CMakeLists.txt @@ -1,5 +1,4 @@ add_mlir_library(TritonTestAnalysis - intel/TestAxisInfo.cpp TestAlias.cpp TestAxisInfo.cpp TestAllocation.cpp diff --git a/third_party/intel/backend/compiler.py b/third_party/intel/backend/compiler.py index 32c8c2e133..83a7202a73 100644 --- a/third_party/intel/backend/compiler.py +++ b/third_party/intel/backend/compiler.py @@ -217,10 +217,10 @@ def make_ttgir(mod, metadata, opt, properties): intel.passes.ttgpuir.add_accelerate_matmul(pm) intel.passes.ttgpuir.add_remove_layout_conversions(pm) intel.passes.ttgpuir.add_materialize_block_pointer(pm) - intel.passes.ttgpuir.add_rewrite_tensor_pointer(pm) +# intel.passes.ttgpuir.add_rewrite_tensor_pointer(pm) intel.passes.ttgpuir.add_pipeline(pm, opt.num_stages, False) - passes.ttgpuir.add_coalesce(pm) + intel.passes.ttgpuir.add_coalesce(pm) intel.passes.ttgpuir.add_remove_layout_conversions(pm) passes.ttgpuir.add_optimize_thread_locality(pm) passes.ttgpuir.add_optimize_dot_operands(pm, True) diff --git a/third_party/intel/include/Dialect/TritonIntelGPU/IR/Utils.h b/third_party/intel/include/Dialect/TritonIntelGPU/IR/Utils.h index 4c0031e2dd..7950fe1377 100644 --- a/third_party/intel/include/Dialect/TritonIntelGPU/IR/Utils.h +++ b/third_party/intel/include/Dialect/TritonIntelGPU/IR/Utils.h @@ -9,17 +9,44 @@ #ifndef TRITON_DIALECT_TRITON_INTEL_GPU_IR_UTILS_H #define TRITON_DIALECT_TRITON_INTEL_GPU_IR_UTILS_H -#include - +#include "intel/include/Analysis/AxisInfo.h" +#include "mlir/IR/Operation.h" +#include "triton/Dialect/TritonGPU/Transforms/Utility.h" #include namespace mlir::triton::gpu::intel { + +/// Calculate the optimal number of elements per thread for a given operation +/// along an axis with greatest continuity. +inline unsigned getNumElementsPerThread( + Operation *op, SmallVector order, + mlir::triton::intel::ModuleAxisInfoAnalysis &axisInfoAnalysis) { + Value val = getMemAccessPtr(op); + Type valTy = val.getType(); + auto ty = + isTensorPointerType(valTy) + ? 
cast(cast(valTy).getPointeeType()) + : cast(valTy); + auto shapePerCTA = getShapePerCTA(ty); + mlir::triton::intel::AxisInfo &valInfo = *axisInfoAnalysis.getAxisInfo(val); + + unsigned elemNumBits = getElementBitWidth(ty); + unsigned elemNumBytes = std::max(elemNumBits / 8, 1u); + unsigned maxMultipleBytes = valInfo.getDivisibility(order[0]); + unsigned maxMultiple = std::max(maxMultipleBytes / elemNumBytes, 1u); + unsigned maxContig = + std::min(valInfo.getContiguity(order[0]), shapePerCTA[order[0]]); + unsigned alignment = std::min(maxMultiple, maxContig); + return std::min(alignment, 128 / elemNumBits); +} + /// Check whether transposed reduction should be performed. /// /// See: https://github.com/intel/intel-xpu-backend-for-triton/issues/1637 inline bool applyTransposedReduction() { return tools::getBoolEnv("TRITON_INTEL_REDUCE_TRANSPOSE"); } + } // namespace mlir::triton::gpu::intel #endif // TRITON_DIALECT_TRITON_INTEL_GPU_IR_UTILS_H diff --git a/third_party/intel/include/Dialect/TritonIntelGPU/Transforms/Passes.td b/third_party/intel/include/Dialect/TritonIntelGPU/Transforms/Passes.td index 42e386fe29..f8659d2711 100644 --- a/third_party/intel/include/Dialect/TritonIntelGPU/Transforms/Passes.td +++ b/third_party/intel/include/Dialect/TritonIntelGPU/Transforms/Passes.td @@ -27,6 +27,23 @@ def TritonIntelGPUAccelerateMatmul ]; } +def TritonIntelGPUCoalesce + : Pass<"tritonintelgpu-coalesce", "mlir::ModuleOp"> { + let summary = "Intel Coalesce"; + + let description = [{ + The pass analyses loads/stores with type `tensor>` or + `tt.ptr>` and replaces the layouts of these operations with + coalesced layouts, i.e. cache friendly access patterns. + Layout conversions are inserted before and after the load/store op + to maintain consistency with the rest of the program. 
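+
+    For example (illustrative; `#orig` and `#coalesced` are placeholder layout
+    names, not names produced by the pass), a load through a block pointer
+      %v = tt.load %ptr : !tt.ptr<tensor<8x64xf8E5M2, #orig>>
+    is rewritten so that the pointer and the loaded tensor use the coalesced
+    layout, followed by a conversion back to the layout expected by its users:
+      %v2 = tt.load %ptr2 : !tt.ptr<tensor<8x64xf8E5M2, #coalesced>>
+      %v3 = triton_gpu.convert_layout %v2 : tensor<8x64xf8E5M2, #coalesced> -> tensor<8x64xf8E5M2, #orig>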
+ }]; + + let dependentDialects = ["mlir::triton::TritonDialect", + "mlir::triton::gpu::TritonGPUDialect", + "mlir::triton::gpu::intel::TritonIntelGPUDialect"]; +} + def TritonIntelGPUDistributeToWarps : Pass<"tritonintelgpu-distribute-to-warps", "mlir::ModuleOp"> { let summary = "distribute the thread block workload to the warps"; diff --git a/third_party/intel/lib/TritonIntelGPUTransforms/CMakeLists.txt b/third_party/intel/lib/TritonIntelGPUTransforms/CMakeLists.txt index 8c2e290ada..9c02e5752c 100644 --- a/third_party/intel/lib/TritonIntelGPUTransforms/CMakeLists.txt +++ b/third_party/intel/lib/TritonIntelGPUTransforms/CMakeLists.txt @@ -1,5 +1,6 @@ add_triton_library(TritonIntelGPUTransforms AccelerateMatmul.cpp + Coalesce.cpp DistributeToWarps.cpp MatchTargetSize.cpp MaterializeBlockPointer.cpp diff --git a/third_party/intel/lib/TritonIntelGPUTransforms/Coalesce.cpp b/third_party/intel/lib/TritonIntelGPUTransforms/Coalesce.cpp new file mode 100644 index 0000000000..05f10e2efc --- /dev/null +++ b/third_party/intel/lib/TritonIntelGPUTransforms/Coalesce.cpp @@ -0,0 +1,385 @@ +#include "intel/include/Analysis/AxisInfo.h" +#include "intel/include/Dialect/TritonIntelGPU/IR/Dialect.h" +#include "intel/include/Dialect/TritonIntelGPU/IR/Utils.h" +#include "intel/include/Dialect/TritonIntelGPU/Transforms/Passes.h" +#include "mlir/IR/Verifier.h" +#include "mlir/Support/LLVM.h" +#include "triton/Dialect/Triton/IR/Dialect.h" +#include "triton/Dialect/Triton/IR/Types.h" +#include "triton/Dialect/Triton/IR/Utility.h" +#include "triton/Dialect/TritonGPU/Transforms/Utility.h" +#include "triton/Tools/StrUtil.h" +#include "llvm/Support/Debug.h" +#include "llvm/Support/raw_ostream.h" + +#define DEBUG_TYPE "tritonintelgpu-coalesce" +#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ") +#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n") + +namespace mlir::triton::gpu::intel { +#define GEN_PASS_DEF_TRITONINTELGPUCOALESCE +#include "intel/include/Dialect/TritonIntelGPU/Transforms/Passes.h.inc" +} // namespace mlir::triton::gpu::intel + +using namespace mlir; +namespace tt = mlir::triton; +namespace ttg = mlir::triton::gpu; +namespace ttgi = mlir::triton::gpu::intel; + +namespace { + +struct CoalescePass + : public ttgi::impl::TritonIntelGPUCoalesceBase { + void + setCoalescedEncoding(tt::intel::ModuleAxisInfoAnalysis &axisInfoAnalysis, + Operation *op, int numWarps, int threadsPerWarp, + llvm::MapVector &layoutMap) { + Value ptr = getMemAccessPtr(op); + LDBG("ptr: " << ptr); + + LDBG("Considering op: " << *op); + LLVM_DEBUG({ + DBGS() << "axis info of pointer: "; + axisInfoAnalysis.getAxisInfo(ptr)->print(llvm::dbgs()); + llvm::dbgs() << "\n"; + }); + + auto contiguity = axisInfoAnalysis.getAxisInfo(ptr)->getContiguity(); + SmallVector order = argSort(contiguity); + LDBG("order=[" << triton::join(order, ", ") << "]"); + + RankedTensorType refTensorType = + tt::isTensorPointerType(ptr.getType()) + ? cast( + cast(ptr.getType()).getPointeeType()) + : cast(ptr.getType()); + + auto matchesShape = [&refTensorType](const Value &val) { + auto rttType = dyn_cast(val.getType()); + return rttType && rttType.getShape() == refTensorType.getShape(); + }; + + // The desired divisibility is the maximum divisibility among all dependent + // pointers which have the same shape and order as `ptr`. 
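+    // These accesses are collected below; the number of elements handled per
+    // thread is then taken as the maximum over all of them (see the perThread
+    // computation), so a single coalesced layout can serve every access in
+    // the group.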
+ llvm::SmallSetVector memAccessesSameOrder; + memAccessesSameOrder.insert(op); + if (ptr.getDefiningOp()) { + for (Operation *use : mlir::multiRootGetSlice(op)) { + Value val = getMemAccessPtr(use); + if (!val || !matchesShape(val) || memAccessesSameOrder.contains(use)) + continue; + auto currOrder = + argSort(axisInfoAnalysis.getAxisInfo(val)->getContiguity()); + if (order == currOrder) { + LDBG("multi-root-slice: insert to memAccessesSameOrder " << *use); + memAccessesSameOrder.insert(use); + } + } + } + + auto shapePerCTA = ttg::getShapePerCTA(refTensorType); + LDBG("shapePerCTA=[" << triton::join(shapePerCTA, ", ") << "]"); + + int numElems = product(shapePerCTA); + int numThreads = numWarps * threadsPerWarp; + unsigned perThread = + ttgi::getNumElementsPerThread(op, order, axisInfoAnalysis); + LDBG("perThread for op: " << perThread); + + for (Operation *opSameOrder : memAccessesSameOrder) { + if (opSameOrder == op) + continue; + unsigned currPerThread = + ttgi::getNumElementsPerThread(opSameOrder, order, axisInfoAnalysis); + LDBG("perThread for opSameOrder: " << currPerThread); + perThread = std::max(perThread, currPerThread); + } + + perThread = std::min(perThread, std::max(numElems / numThreads, 1)); + LDBG("perThread: " << perThread); + + if (perThread <= 1) + return; + + if (!dyn_cast(op)) { + // For ops that can result in a global memory write, we should enforce + // that each thread handles at most 128 bits, which is the widest + // available vectorized store op; otherwise, the store will have "gaps" + // in the memory write at the warp level, resulting in worse performance. + // For loads, we can expect that the gaps won't matter due to the L1 + // cache. + perThread = std::min(perThread, ttgi::getNumElementsPerThread( + op, order, axisInfoAnalysis)); + } + SmallVector sizePerThread(refTensorType.getRank(), 1); + sizePerThread[order[0]] = perThread; + + auto CTALayout = ttg::getCTALayout(refTensorType.getEncoding()); + layoutMap[op] = ttg::BlockedEncodingAttr::get( + &getContext(), refTensorType.getShape(), sizePerThread, order, numWarps, + threadsPerWarp, CTALayout); + } + + static RankedTensorType getNewType(RankedTensorType tensorType, + Attribute encoding) { + return RankedTensorType::get(tensorType.getShape(), + tensorType.getElementType(), encoding); + } + + // Find the defining makeTensorPtrOp operation of the given value. + static std::optional + findDefiningMakeTensorPtrOp(Value val) { + if (auto arg = dyn_cast(val)) { + Operation *parentOp = val.getParentBlock()->getParentOp(); + assert(isa(parentOp) && "Expected a scf::ForOp"); + auto loopArg = + cast(parentOp).getInitArgs()[arg.getArgNumber() - 1]; + return findDefiningMakeTensorPtrOp(loopArg); + } + + if (auto advanceOp = val.getDefiningOp()) + return findDefiningMakeTensorPtrOp(advanceOp.getPtr()); + if (auto makePtrOp = val.getDefiningOp()) + return makePtrOp; + + return std::nullopt; + } + + static bool filterUser(Operation *op) { + // Yield operations trigger updating the layout of the containing loop + // results, so don't skip them. + if (isa(op)) + return false; + + // Skip operations that don't yield a result and contain no regions. + if (op->getNumResults() == 0 && op->getNumRegions() == 0) + return true; + + // Operations that do not yield a block pointer aren't interesting. 
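+    // Their result types never need to be rewritten with the new pointee
+    // layout, so they can be skipped safely.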
+ if (op->getNumRegions() == 0 && + llvm::none_of(op->getResultTypes(), [](Type resType) { + return tt::isTensorPointerType(resType); + })) + return true; + + return false; + } + + // Propagate the \p root block argument operation output layout along the + // def-use chain. + static void propagateLayout(BlockArgument arg, Attribute layout, + IRRewriter &rewriter) { + llvm::errs() << "arg: " << arg << "\n"; + for (Operation *user : arg.getUsers()) { + llvm::errs() << "user: " << *user << "\n\n"; + if (filterUser(user)) { + llvm::errs() << "SKIP\n"; + continue; + } + + if (auto yieldOp = dyn_cast(user)) { + // Modify and propagate the result of the enclosing loop. + auto forOp = yieldOp->getParentOfType(); + changeAndPropagateLayout(forOp, layout, rewriter); + continue; + } + + changeAndPropagateLayout(user, layout, rewriter); + } + } + + static void propagateLayout(Operation *root, Attribute layout, + IRRewriter &rewriter) { + assert(root && root->getNumResults() != 0 && + "Expecting an operation yielding a result"); + + // llvm::errs() << "root: " << *root << "\n\n"; + for (Operation *user : root->getUsers()) { + llvm::errs() << "user: " << *user << "\n\n"; + if (filterUser(user)) { + llvm::errs() << "SKIP\n"; + continue; + } + + if (auto yieldOp = dyn_cast(user)) { + // Modify and propagate the result of the enclosing loop. + auto forOp = yieldOp->getParentOfType(); + changeAndPropagateLayout(forOp, layout, rewriter); + continue; + } + + if (auto forOp = dyn_cast(user)) { + for (BlockArgument arg : forOp.getRegionIterArgs()) { + Value loopArg = forOp.getInitArgs()[arg.getArgNumber() - 1]; + for (OpResult res : root->getResults()) { + if (res == loopArg && tt::isTensorPointerType(res.getType())) { + llvm::errs() << "arg: " << arg << "\n"; + llvm::errs() << "loopArg: " << loopArg << "\n"; + llvm::errs() << "arg type: " << arg.getType() << "\n"; + + // Modify the layout of the loop init argument... + tt::PointerType ptrType = cast(arg.getType()); + auto tensorType = + cast(ptrType.getPointeeType()); + arg.setType(tt::PointerType::get(getNewType(tensorType, layout), + ptrType.getAddressSpace())); + + // ... and then propagate it to the operations in the loop. + propagateLayout(arg, layout, rewriter); + } + } + } + continue; + } + + changeAndPropagateLayout(user, layout, rewriter); + } + } + + // Change the \p layout of the \p op result(s) and propagate the new result + // type to its users. + static void changeAndPropagateLayout(Operation *op, Attribute layout, + IRRewriter &rewriter) { + assert(op && op->getNumResults() != 0 && + "Expecting operation yielding a result"); + + rewriter.modifyOpInPlace(op, [&]() { + for (Value res : op->getResults()) { + if (!tt::isTensorPointerType(res.getType())) + continue; + + auto ptrType = cast(res.getType()); + auto tensorType = cast(ptrType.getPointeeType()); + res.setType(tt::PointerType::get(getNewType(tensorType, layout), + ptrType.getAddressSpace())); + } + }); + llvm::errs() << "Coalesced op: " << *op << "\n"; + + propagateLayout(op, layout, rewriter); + } + + void coalesceOp(Attribute encoding, Operation *op) { + llvm::errs() << "Coalescing op: " << *op << "\n"; + + OpBuilder builder(op); + IRRewriter rewriter(builder); + + // Convert operands + // Note: for load/store with a blocked pointers argument we cannot change + // the operand type, instead we change the output type of + // `make_tensor_ptr` and propagate the new output type along the def-use + // chain. 
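+    // For example, in a chain `make_tensor_ptr -> tt.advance -> tt.load`, the
+    // coalesced layout is set on the make_tensor_ptr result (found by walking
+    // back through tt.advance and loop iter_args) and then pushed forward to
+    // every user that yields a block pointer.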
+ SmallVector newArgs; + for (Value operand : op->getOperands()) { + auto tensorType = dyn_cast(operand.getType()); + if (tensorType && + !isa(tensorType.getEncoding())) { + RankedTensorType newType = getNewType(tensorType, encoding); + newArgs.push_back(rewriter.create( + op->getLoc(), newType, operand)); + } else { + assert(isa(operand.getType()) && + "Expecting operand to have blocked pointer type"); + auto defOp = findDefiningMakeTensorPtrOp(operand); + assert(defOp && "Expected a make_tensor_ptr operation"); + + llvm::errs() << "Found make_tensor_ptr definition: " << *defOp << "\n"; + changeAndPropagateLayout(*defOp, encoding, rewriter); + newArgs.push_back(operand); + } + } + + // Convert output types + SmallVector newTypes; + for (auto t : op->getResultTypes()) { + bool isAsync = isa(op); + assert(!isAsync && + "AsyncCopyGlobalToLocalOp not supported for Intel GPU"); + newTypes.push_back(getNewType(cast(t), encoding)); + } + + // Construct new op with the new encoding. + Operation *newOp = + rewriter.create(op->getLoc(), op->getName().getIdentifier(), newArgs, + newTypes, op->getAttrs()); + + // Cast the results back to the original layout. + for (size_t i = 0; i < op->getNumResults(); i++) { + Value newResult = newOp->getResult(i); + if (newTypes[i] != op->getResultTypes()[i]) { + newResult = rewriter.create( + op->getLoc(), op->getResult(i).getType(), newResult); + } + op->getResult(i).replaceAllUsesWith(newResult); + } + op->erase(); + + llvm::errs() << "newOp: " << *newOp << "\n"; + assert(succeeded(verify(newOp)) && "Operation verification failed"); + } + + void runOnOperation() override { + // Run axis info analysis + ModuleOp moduleOp = getOperation(); + tt::intel::ModuleAxisInfoAnalysis axisInfoAnalysis(moduleOp); + + // For each i/o operation, we determine what layout + // the pointers should have for best memory coalescing + llvm::MapVector layoutMap; + moduleOp.walk([&](Operation *curr) { + Value ptr = getMemAccessPtr(curr); + if (!ptr) + return; + + RankedTensorType refTensorType = + tt::isTensorPointerType(ptr.getType()) + ? cast( + cast(ptr.getType()).getPointeeType()) + : dyn_cast(ptr.getType()); + if (!refTensorType || !refTensorType.getEncoding()) + return; + + int numWarps = ttg::TritonGPUDialect::getNumWarps(moduleOp); + int threadsPerWarp = ttg::TritonGPUDialect::getThreadsPerWarp(moduleOp); + setCoalescedEncoding(axisInfoAnalysis, curr, numWarps, threadsPerWarp, + layoutMap); + }); + + llvm::errs() << "layoutMap:\n"; + for (auto [op, encoding] : layoutMap) { + llvm::errs() << "op: " << *op << "\n"; + llvm::errs() << "encoding: " << encoding << "\n"; + } + llvm::errs() << "\n"; + + // For each memory op that has a layout L1: + // 1. Create a coalesced memory layout L2 of the pointer operands + // 2. Convert all operands from layout L1 to layout L2 + // 3. Create a new memory op that consumes these operands and + // produces a tensor with layout L2 + // 4. Convert the output of this new memory op back to L1 + // 5. 
Replace all the uses of the original memory op by the new one + for (auto [op, layout] : layoutMap) { + coalesceOp(layout, op); + if (failed(verify(moduleOp))) { + for (Operation &op1 : moduleOp.getOps()) { + if (isa(op1)) { + for (Operation &op2 : cast(op1).getOps()) { + if (failed(verify(&op2))) { + llvm::errs() << "op2: " << op2 << "\n"; + llvm::errs() << "Operation verification failed.\n"; + } + } + } + } + llvm::errs() << "Module verification failed.\n"; + llvm::errs() << "mod: " << moduleOp << "\n"; + assert(false); + } + llvm::errs() << "Module verified.\n"; + } + } +}; + +} // namespace diff --git a/third_party/intel/triton_xpu.cc b/third_party/intel/triton_xpu.cc index eb7be1c080..5b71d55819 100644 --- a/third_party/intel/triton_xpu.cc +++ b/third_party/intel/triton_xpu.cc @@ -45,10 +45,9 @@ using ret = py::return_value_policy; pm.addPass(builder({val0, val1})); \ }) #define ADD_PASS_WRAPPER_OPT_5(name, builder, ty0, ty1, ty2, ty3, ty4) \ - m.def(name, [](mlir::PassManager &pm, ty0 val0, ty1 val1, ty2 val2, \ - ty3 val3, ty4 val4) { \ - pm.addPass(builder({val0, val1, val2, val3, val4})); \ - }) + m.def(name, \ + [](mlir::PassManager &pm, ty0 val0, ty1 val1, ty2 val2, ty3 val3, \ + ty4 val4) { pm.addPass(builder({val0, val1, val2, val3, val4})); }) static uint32_t findKernels(llvm::Module &M, std::set &functions) { @@ -82,6 +81,7 @@ void init_triton_intel_passes_ttgpuir(py::module &&m) { gpu::intel::createTritonIntelGPURemoveLayoutConversions); ADD_PASS_WRAPPER_0("add_rewrite_tensor_pointer", gpu::intel::createTritonIntelGPURewriteTensorPointer); + ADD_PASS_WRAPPER_0("add_coalesce", gpu::intel::createTritonIntelGPUCoalesce); ADD_PASS_WRAPPER_OPT_2("add_prefetch_block", gpu::intel::createTritonIntelGPUPrefetchBlock, int, bool); From fa53ced81c7a172624898b36cf5e4759b9e51590 Mon Sep 17 00:00:00 2001 From: "Tiotto, Ettore" Date: Wed, 16 Oct 2024 20:54:48 +0000 Subject: [PATCH 03/32] Fix pre_commit Signed-off-by: Tiotto, Ettore --- test/TritonIntelGPU/coalesce.mlir | 78 +++++++++++++-------------- third_party/intel/backend/compiler.py | 2 +- third_party/intel/triton_xpu.cc | 7 +-- 3 files changed, 44 insertions(+), 43 deletions(-) diff --git a/test/TritonIntelGPU/coalesce.mlir b/test/TritonIntelGPU/coalesce.mlir index 1429dc2065..47e6e5df29 100644 --- a/test/TritonIntelGPU/coalesce.mlir +++ b/test/TritonIntelGPU/coalesce.mlir @@ -146,49 +146,49 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : // CHECK: [[BLOCKED_LAYOUT2:#.*]] = #triton_gpu.blocked<{sizePerThread = [16, 1], threadsPerWarp = [4, 4], warpsPerCTA = [1, 4], order = [0, 1]}> // CHECK: @test_block_ptrs tt.func public @test_block_ptrs(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: !tt.ptr {tt.divisibility = 16 : i32}, %arg3: f32, %arg4: !tt.ptr {tt.divisibility = 16 : i32}, %arg5: !tt.ptr {tt.divisibility = 16 : i32}, %arg6: i32 {tt.divisibility = 16 : i32}, %arg7: i32 {tt.divisibility = 16 : i32}, %arg8: i32 {tt.divisibility = 16 : i32}, %arg9: i32, %arg10: i32, %arg11: i32 {tt.divisibility = 16 : i32}, %arg12: i32, %arg13: i32, %arg14: i32, %arg15: i32, %arg16: i32, %arg17: i32 {tt.divisibility = 16 : i32}, %arg18: i32, %arg19: i32, %arg20: i32) { - %cst = arith.constant dense<0.000000e+00> : tensor<8x16xf32, #dpas> - %cst_0 = arith.constant dense<0.000000e+00> : tensor<8xf32, #blocked> - %cst_1 = arith.constant dense<0xFF800000> : tensor<8xf32, #blocked> - %c1_i32 = arith.constant 1 : i32 - %c16_i32 = arith.constant 16 : i32 - %cst_2 = 
arith.constant dense<0.000000e+00> : tensor<8x64xf32, #blocked1> - %c0_i32 = arith.constant 0 : i32 - %c1_i64 = arith.constant 1 : i64 - %c64_i64 = arith.constant 64 : i64 - %c8_i32 = arith.constant 8 : i32 - %0 = tt.get_program_id x : i32 - %1 = tt.get_program_id y : i32 - %2 = arith.divsi %1, %arg19 : i32 - %3 = arith.remsi %1, %arg19 : i32 - %4 = arith.extsi %2 : i32 to i64 - %5 = arith.extsi %arg6 : i32 to i64 - %6 = arith.muli %4, %5 : i64 - %7 = arith.extsi %3 : i32 to i64 - %8 = arith.extsi %arg7 : i32 to i64 - %9 = arith.muli %7, %8 : i64 - %10 = arith.addi %6, %9 : i64 - %11 = tt.addptr %arg0, %10 : !tt.ptr, i64 - %12 = arith.muli %0, %c8_i32 : i32 - %13 = arith.extsi %arg20 : i32 to i64 - %14 = arith.extsi %arg8 : i32 to i64 + %cst = arith.constant dense<0.000000e+00> : tensor<8x16xf32, #dpas> + %cst_0 = arith.constant dense<0.000000e+00> : tensor<8xf32, #blocked> + %cst_1 = arith.constant dense<0xFF800000> : tensor<8xf32, #blocked> + %c1_i32 = arith.constant 1 : i32 + %c16_i32 = arith.constant 16 : i32 + %cst_2 = arith.constant dense<0.000000e+00> : tensor<8x64xf32, #blocked1> + %c0_i32 = arith.constant 0 : i32 + %c1_i64 = arith.constant 1 : i64 + %c64_i64 = arith.constant 64 : i64 + %c8_i32 = arith.constant 8 : i32 + %0 = tt.get_program_id x : i32 + %1 = tt.get_program_id y : i32 + %2 = arith.divsi %1, %arg19 : i32 + %3 = arith.remsi %1, %arg19 : i32 + %4 = arith.extsi %2 : i32 to i64 + %5 = arith.extsi %arg6 : i32 to i64 + %6 = arith.muli %4, %5 : i64 + %7 = arith.extsi %3 : i32 to i64 + %8 = arith.extsi %arg7 : i32 to i64 + %9 = arith.muli %7, %8 : i64 + %10 = arith.addi %6, %9 : i64 + %11 = tt.addptr %arg0, %10 : !tt.ptr, i64 + %12 = arith.muli %0, %c8_i32 : i32 + %13 = arith.extsi %arg20 : i32 to i64 + %14 = arith.extsi %arg8 : i32 to i64 // CHECK: [[PTR1:%.*]] = tt.make_tensor_ptr {{.*}} : %15 = tt.make_tensor_ptr %11, [%13, %c64_i64], [%14, %c1_i64], [%12, %c0_i32] {order = array} : > - %16 = tt.addptr %arg1, %10 : !tt.ptr, i64 - %17 = arith.extsi %arg11 : i32 to i64 + %16 = tt.addptr %arg1, %10 : !tt.ptr, i64 + %17 = arith.extsi %arg11 : i32 to i64 // CHECK: [[PTR2:%.*]] = tt.make_tensor_ptr {{.*}} : %18 = tt.make_tensor_ptr %16, [%c64_i64, %13], [%c1_i64, %17], [%c0_i32, %c0_i32] {order = array} : > - %19 = tt.addptr %arg5, %10 : !tt.ptr, i64 - %20 = arith.extsi %arg17 : i32 to i64 + %19 = tt.addptr %arg5, %10 : !tt.ptr, i64 + %20 = arith.extsi %arg17 : i32 to i64 // CHECK: [[PTR3:%.*]] = tt.make_tensor_ptr {{.*}} : %21 = tt.make_tensor_ptr %19, [%13, %c64_i64], [%20, %c1_i64], [%12, %c0_i32] {order = array} : > %22 = tt.make_range {end = 8 : i32, start = 0 : i32} : tensor<8xi32, #blocked> - %23 = tt.splat %12 : i32 -> tensor<8xi32, #blocked> - %24 = arith.addi %23, %22 : tensor<8xi32, #blocked> + %23 = tt.splat %12 : i32 -> tensor<8xi32, #blocked> + %24 = arith.addi %23, %22 : tensor<8xi32, #blocked> // CHECK: [[LOAD1:%.*]] = tt.load [[PTR1]] : !tt.ptr // CHECK-NEXT: triton_gpu.convert_layout [[LOAD1]] : tensor<8x64xf8E5M2, [[BLOCKED_LAYOUT1]]> -> tensor<8x64xf8E5M2, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> %25 = tt.load %15 : !tt.ptr> - %26 = arith.addi %0, %c1_i32 : i32 - %27 = arith.muli %26, %c8_i32 : i32 + %26 = arith.addi %0, %c1_i32 : i32 + %27 = arith.muli %26, %c8_i32 : i32 // CHECK: [[ADVANCE1:%.*]] = tt.advance [[PTR2]], {{.*}} : > %28 = tt.advance %18, [%c0_i32, %12] : > // CHECK: [[RES:%.*:2]] = scf.for {{.*}} iter_args(%arg22 = %cst_1, %arg23 = [[ADVANCE1]]) -> (tensor<8xf32, #blocked>, !tt.ptr>) @@ -202,8 +202,8 @@ module attributes 
{"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : %40 = triton_gpu.convert_layout %39 : tensor<8x16xf32, #dpas> -> tensor<8x16xf32, #blocked2> %41 = "tt.reduce"(%40) <{axis = 1 : i32}> ({ ^bb0(%arg24: f32, %arg25: f32): - %44 = arith.maxnumf %arg24, %arg25 : f32 - tt.reduce.return %44 : f32 + %44 = arith.maxnumf %arg24, %arg25 : f32 + tt.reduce.return %44 : f32 }) : (tensor<8x16xf32, #blocked2>) -> tensor<8xf32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>> %42 = triton_gpu.convert_layout %41 : tensor<8xf32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>> -> tensor<8xf32, #blocked> // CHECK: [[ADVANCE2:%.*]] = tt.advance %arg23, {{.*}} : > @@ -219,7 +219,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : tt.store %34, %30 : tensor<8x!tt.ptr, #blocked> %35 = tt.fp_to_fp %cst_2, rounding = rtne : tensor<8x64xf32, #blocked1> -> tensor<8x64xf8E5M2, #blocked1> tt.store %21, %35 : !tt.ptr> - tt.return + tt.return } } @@ -254,19 +254,19 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : %13 = arith.extsi %arg19 : i32 to i64 %19 = tt.addptr %arg1, %10 : !tt.ptr, i64 %20 = arith.extsi %arg11 : i32 to i64 - // CHECK: [[PTR1:%.*]] = tt.make_tensor_ptr {{.*}} : + // CHECK: [[PTR1:%.*]] = tt.make_tensor_ptr {{.*}} : %21 = tt.make_tensor_ptr %19, [%c64_i64, %13], [%c1_i64, %20], [%c0_i32, %c0_i32] {order = array} : > // CHECK: [[RES:%.*]]:2 = scf.for {{.*}} iter_args(%arg6 = %cst, %arg7 = [[PTR1]]) -> (tensor<64xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>>, !tt.ptr>) %33:2 = scf.for %arg21 = %c0_i32 to %12 step %c32_i32 iter_args(%arg22 = %cst_1, %arg23 = %21) -> (tensor<64xf32, #triton_gpu.slice<{dim = 1, parent = #dpas}>>, !tt.ptr>) : i32 { // CHECK: [[LOAD:%.*]] = tt.load %arg7 : !tt.ptr> - // CHECK-NEXT: triton_gpu.convert_layout [[LOAD]] : tensor<64x32xf8E5M2, [[BLOCKED_LAYOUT]]> -> tensor<64x32xf8E5M2, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> + // CHECK-NEXT: triton_gpu.convert_layout [[LOAD]] : tensor<64x32xf8E5M2, [[BLOCKED_LAYOUT]]> -> tensor<64x32xf8E5M2, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> // CHECK-NEXT: scf.yield %arg6, %arg7 : tensor<64xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>>, !tt.ptr> %load = tt.load %arg23 : !tt.ptr> scf.yield %arg22, %arg23 : tensor<64xf32, #triton_gpu.slice<{dim = 1, parent = #dpas}>>, !tt.ptr> } // CHECK: scf.for {{.*}} iter_args(%arg6 = [[RES]]#0, %arg7 = [[RES]]#1) -> (tensor<64xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>>, !tt.ptr>) %34:2 = scf.for %arg21 = %c0_i32 to %12 step %c32_i32 iter_args(%arg22 = %33#0, %arg23 = %33#1) -> (tensor<64xf32, #triton_gpu.slice<{dim = 1, parent = #dpas}>>, !tt.ptr>) : i32 { - // CHECK: scf.yield %arg6, %arg7 : tensor<64xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>>, !tt.ptr> + // CHECK: scf.yield %arg6, %arg7 : tensor<64xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>>, !tt.ptr> scf.yield %arg22, %arg23 : tensor<64xf32, #triton_gpu.slice<{dim = 1, parent = #dpas}>>, !tt.ptr> } tt.return diff --git a/third_party/intel/backend/compiler.py b/third_party/intel/backend/compiler.py index 83a7202a73..2184584ba2 100644 --- a/third_party/intel/backend/compiler.py +++ b/third_party/intel/backend/compiler.py @@ -217,7 +217,7 @@ def make_ttgir(mod, metadata, opt, properties): intel.passes.ttgpuir.add_accelerate_matmul(pm) intel.passes.ttgpuir.add_remove_layout_conversions(pm) intel.passes.ttgpuir.add_materialize_block_pointer(pm) -# intel.passes.ttgpuir.add_rewrite_tensor_pointer(pm) + # 
intel.passes.ttgpuir.add_rewrite_tensor_pointer(pm) intel.passes.ttgpuir.add_pipeline(pm, opt.num_stages, False) intel.passes.ttgpuir.add_coalesce(pm) diff --git a/third_party/intel/triton_xpu.cc b/third_party/intel/triton_xpu.cc index 5b71d55819..31262579fc 100644 --- a/third_party/intel/triton_xpu.cc +++ b/third_party/intel/triton_xpu.cc @@ -45,9 +45,10 @@ using ret = py::return_value_policy; pm.addPass(builder({val0, val1})); \ }) #define ADD_PASS_WRAPPER_OPT_5(name, builder, ty0, ty1, ty2, ty3, ty4) \ - m.def(name, \ - [](mlir::PassManager &pm, ty0 val0, ty1 val1, ty2 val2, ty3 val3, \ - ty4 val4) { pm.addPass(builder({val0, val1, val2, val3, val4})); }) + m.def(name, [](mlir::PassManager &pm, ty0 val0, ty1 val1, ty2 val2, \ + ty3 val3, ty4 val4) { \ + pm.addPass(builder({val0, val1, val2, val3, val4})); \ + }) static uint32_t findKernels(llvm::Module &M, std::set &functions) { From 5a6cf81f74699096110395fd4f21f1220234f7f6 Mon Sep 17 00:00:00 2001 From: "Tiotto, Ettore" Date: Thu, 17 Oct 2024 20:15:53 +0000 Subject: [PATCH 04/32] Fix functional problem and add lit test Signed-off-by: Tiotto, Ettore --- test/TritonIntelGPU/coalesce.mlir | 74 +++++++++++++++++-- .../lib/TritonIntelGPUTransforms/Coalesce.cpp | 61 ++++++++++++++- 2 files changed, 128 insertions(+), 7 deletions(-) diff --git a/test/TritonIntelGPU/coalesce.mlir b/test/TritonIntelGPU/coalesce.mlir index 47e6e5df29..8dfe6d7ebe 100644 --- a/test/TritonIntelGPU/coalesce.mlir +++ b/test/TritonIntelGPU/coalesce.mlir @@ -134,6 +134,8 @@ module attributes {"triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-war // ----- +// COM: Test coalescing on blocked pointers: coalescable load using block pointer in a SCF for loop. + #blocked = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [16], warpsPerCTA = [4], order = [0]}> #blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 16], warpsPerCTA = [1, 4], order = [1, 0]}> #blocked2 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 16], warpsPerCTA = [4, 1], order = [1, 0]}> @@ -225,6 +227,8 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : // ----- +// COM: Test coalescing on blocked pointers: loop results used by another loop. 
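+// COM: The coalesced pointer type produced by the first loop must carry over to the second loop's iter_args and its yield.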
+ #blocked = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [16], warpsPerCTA = [4], order = [0]}> #dpas = #triton_intel_gpu.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 2, threadsPerWarp = 16, warpsPerCTA = [4, 1], repCluster = [2, 2], A = [16, 16], B = [16, 32], C = [16, 32]}> #dot2 = #triton_gpu.dot_op<{opIdx = 1, parent = #dpas, kWidth = 2}> @@ -256,19 +260,79 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : %20 = arith.extsi %arg11 : i32 to i64 // CHECK: [[PTR1:%.*]] = tt.make_tensor_ptr {{.*}} : %21 = tt.make_tensor_ptr %19, [%c64_i64, %13], [%c1_i64, %20], [%c0_i32, %c0_i32] {order = array} : > - // CHECK: [[RES:%.*]]:2 = scf.for {{.*}} iter_args(%arg6 = %cst, %arg7 = [[PTR1]]) -> (tensor<64xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>>, !tt.ptr>) + // CHECK: [[RES:%.*]]:2 = scf.for {{.*}} iter_args([[ARG1:%.*]] = %cst, [[ARG2:%.*]] = [[PTR1]]) -> (tensor<64xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>>, !tt.ptr>) %33:2 = scf.for %arg21 = %c0_i32 to %12 step %c32_i32 iter_args(%arg22 = %cst_1, %arg23 = %21) -> (tensor<64xf32, #triton_gpu.slice<{dim = 1, parent = #dpas}>>, !tt.ptr>) : i32 { - // CHECK: [[LOAD:%.*]] = tt.load %arg7 : !tt.ptr> + // CHECK: [[LOAD:%.*]] = tt.load [[ARG2]] : !tt.ptr> // CHECK-NEXT: triton_gpu.convert_layout [[LOAD]] : tensor<64x32xf8E5M2, [[BLOCKED_LAYOUT]]> -> tensor<64x32xf8E5M2, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> - // CHECK-NEXT: scf.yield %arg6, %arg7 : tensor<64xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>>, !tt.ptr> + // CHECK-NEXT: scf.yield [[ARG1]], [[ARG2]] : tensor<64xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>>, !tt.ptr> %load = tt.load %arg23 : !tt.ptr> scf.yield %arg22, %arg23 : tensor<64xf32, #triton_gpu.slice<{dim = 1, parent = #dpas}>>, !tt.ptr> } - // CHECK: scf.for {{.*}} iter_args(%arg6 = [[RES]]#0, %arg7 = [[RES]]#1) -> (tensor<64xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>>, !tt.ptr>) + // CHECK: scf.for {{.*}} iter_args([[ARG1:%.*]] = [[RES]]#0, [[ARG2:%.*]] = [[RES]]#1) -> (tensor<64xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>>, !tt.ptr>) %34:2 = scf.for %arg21 = %c0_i32 to %12 step %c32_i32 iter_args(%arg22 = %33#0, %arg23 = %33#1) -> (tensor<64xf32, #triton_gpu.slice<{dim = 1, parent = #dpas}>>, !tt.ptr>) : i32 { - // CHECK: scf.yield %arg6, %arg7 : tensor<64xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>>, !tt.ptr> + // CHECK: scf.yield [[ARG1]], [[ARG2]] : tensor<64xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>>, !tt.ptr> scf.yield %arg22, %arg23 : tensor<64xf32, #triton_gpu.slice<{dim = 1, parent = #dpas}>>, !tt.ptr> } tt.return } } + +// ----- + +// COM: Test coalescing on blocked pointers: loop with 2 output blocked pointers. 
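+// COM: Each loop-carried block pointer must keep a consistent pointee layout across the loop signature, the tt.advance ops, and the scf.yield.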
+ +#blocked = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [16], warpsPerCTA = [4], order = [0]}> +#mma = #triton_intel_gpu.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 2, threadsPerWarp = 16, warpsPerCTA = [4, 1], repCluster = [2, 2], A = [16, 16], B = [16, 32], C = [16, 32]}> +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 16 : i32} { + // CHECK: [[BLOCKED_LAYOUT:#.*]] = #triton_gpu.blocked<{sizePerThread = [16, 1], threadsPerWarp = [4, 4], warpsPerCTA = [1, 4], order = [0, 1]}> + // CHECK: @test_block_ptrs + tt.func public @test_block_ptrs(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: !tt.ptr {tt.divisibility = 16 : i32}, %arg6: i32 {tt.divisibility = 16 : i32}, %arg7: i32 {tt.divisibility = 16 : i32}, %arg8: i32, %arg11: i32 {tt.divisibility = 16 : i32}, %arg14: i32, %arg19: i32, %arg20: i32) { + %c32_i32 = arith.constant 32 : i32 + %cst_2 = arith.constant dense<0.000000e+00> : tensor<64x32xf32, #mma> + %c64_i32 = arith.constant 64 : i32 + %c64_i64 = arith.constant 64 : i64 + %c1_i64 = arith.constant 1 : i64 + %c0_i32 = arith.constant 0 : i32 + %0 = tt.get_program_id x : i32 + %1 = tt.get_program_id y : i32 + %2 = arith.divsi %1, %arg19 : i32 + %3 = arith.remsi %1, %arg19 : i32 + %4 = arith.extsi %2 : i32 to i64 + %5 = arith.extsi %arg6 : i32 to i64 + %6 = arith.muli %4, %5 : i64 + %7 = arith.extsi %3 : i32 to i64 + %8 = arith.extsi %arg7 : i32 to i64 + %9 = arith.muli %7, %8 : i64 + %10 = arith.addi %6, %9 : i64 + %11 = tt.addptr %arg0, %10 : !tt.ptr, i64 + %12 = arith.muli %0, %c64_i32 : i32 + %13 = arith.extsi %arg20 : i32 to i64 + %14 = arith.extsi %arg8 : i32 to i64 + %15 = tt.make_tensor_ptr %11, [%13, %c64_i64], [%14, %c1_i64], [%12, %c0_i32] {order = array} : >> + %16 = tt.addptr %arg2, %10 : !tt.ptr, i64 + %17 = arith.extsi %arg14 : i32 to i64 + // CHECK: [[PTR1:%.*]] = tt.make_tensor_ptr {{.*}} : >> + %18 = tt.make_tensor_ptr %16, [%13, %c64_i64], [%c1_i64, %17], [%c0_i32, %c0_i32] {order = array} : >> + %19 = tt.addptr %arg1, %10 : !tt.ptr, i64 + %20 = arith.extsi %arg11 : i32 to i64 + // CHECK: [[PTR2:%.*]] = tt.make_tensor_ptr {{.*}} : + %21 = tt.make_tensor_ptr %19, [%c64_i64, %13], [%c1_i64, %20], [%c0_i32, %c0_i32] {order = array} : >> + %32 = tt.load %15 : !tt.ptr>> + // CHECK: scf.for {{.*}} iter_args([[ARG1:%.*]] = [[PTR2]], [[ARG2:%.*]] = [[PTR1]]) -> (!tt.ptr>, !tt.ptr>>) + %35:2 = scf.for %arg21 = %c0_i32 to %12 step %c32_i32 iter_args(%arg25 = %21, %arg26 = %18) -> (!tt.ptr>>, !tt.ptr>>) : i32 { + // CHECK: [[LOAD:%.*]] = tt.load [[ARG1]] : !tt.ptr> + // CHECK-NEXT: triton_gpu.convert_layout [[LOAD]] : tensor<64x32xf8E5M2, [[BLOCKED_LAYOUT]]> -> tensor<64x32xf8E5M2, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> + %58 = tt.load %arg25 : !tt.ptr>> + %59 = tt.fp_to_fp %32 : tensor<64x64xf8E5M2, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> -> tensor<64x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> + %60 = tt.fp_to_fp %58 : tensor<64x32xf8E5M2, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<64x32xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> + %61 = tt.dot %59, %60, %cst_2, inputPrecision = tf32 : tensor<64x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * tensor<64x32xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<64x32xf32, #mma> + // CHECK-DAG: [[ADVANCE1:%.*]] = 
tt.advance [[ARG1]], {{.*}} : > + // CHECK-DAG: [[ADVANCE2:%.*]] = tt.advance [[ARG2]], {{.*}} : >> + // CHECK-NEXT: scf.yield [[ADVANCE1]], [[ADVANCE2]] : !tt.ptr>, !tt.ptr>> + %84 = tt.advance %arg26, [%c32_i32, %c0_i32] : >> + %85 = tt.advance %arg25, [%c0_i32, %c32_i32] : >> + scf.yield %85, %84 : !tt.ptr>>, !tt.ptr>> + } + tt.return + } +} diff --git a/third_party/intel/lib/TritonIntelGPUTransforms/Coalesce.cpp b/third_party/intel/lib/TritonIntelGPUTransforms/Coalesce.cpp index 05f10e2efc..4a05523325 100644 --- a/third_party/intel/lib/TritonIntelGPUTransforms/Coalesce.cpp +++ b/third_party/intel/lib/TritonIntelGPUTransforms/Coalesce.cpp @@ -30,6 +30,7 @@ namespace { struct CoalescePass : public ttgi::impl::TritonIntelGPUCoalesceBase { +private: void setCoalescedEncoding(tt::intel::ModuleAxisInfoAnalysis &axisInfoAnalysis, Operation *op, int numWarps, int threadsPerWarp, @@ -180,7 +181,29 @@ struct CoalescePass if (auto yieldOp = dyn_cast(user)) { // Modify and propagate the result of the enclosing loop. auto forOp = yieldOp->getParentOfType(); - changeAndPropagateLayout(forOp, layout, rewriter); + + rewriter.modifyOpInPlace(forOp, [&]() { + for (auto [opType, res] : + llvm::zip(yieldOp->getOperandTypes(), forOp.getResults())) { + if (opType == res.getType()) + continue; + + assert(tt::isTensorPointerType(res.getType()) && + tt::isTensorPointerType(opType) && + "Expecting blocked pointers"); + assert(cast( + cast(opType).getPointeeType()) + .getEncoding() == layout && + "Unexpected layout"); + + auto resType = cast(res.getType()); + auto tensorType = cast(resType.getPointeeType()); + res.setType(tt::PointerType::get(getNewType(tensorType, layout), + resType.getAddressSpace())); + } + }); + + propagateLayout(forOp, layout, rewriter); continue; } @@ -204,7 +227,29 @@ struct CoalescePass if (auto yieldOp = dyn_cast(user)) { // Modify and propagate the result of the enclosing loop. auto forOp = yieldOp->getParentOfType(); - changeAndPropagateLayout(forOp, layout, rewriter); + + rewriter.modifyOpInPlace(forOp, [&]() { + for (auto [opType, res] : + llvm::zip(yieldOp->getOperandTypes(), forOp.getResults())) { + if (opType == res.getType()) + continue; + + assert(tt::isTensorPointerType(res.getType()) && + tt::isTensorPointerType(opType) && + "Expecting blocked pointers"); + assert(cast( + cast(opType).getPointeeType()) + .getEncoding() == layout && + "Unexpected layout"); + + auto resType = cast(res.getType()); + auto tensorType = cast(resType.getPointeeType()); + res.setType(tt::PointerType::get(getNewType(tensorType, layout), + resType.getAddressSpace())); + } + }); + + propagateLayout(forOp, layout, rewriter); continue; } @@ -248,6 +293,10 @@ struct CoalescePass if (!tt::isTensorPointerType(res.getType())) continue; + // Problem: if the operation is a for loop we cannot modify the layout + // of all the tensor ptr results, we need to modify only the one used by + // the yield operation. 
+ auto ptrType = cast(res.getType()); auto tensorType = cast(ptrType.getPointeeType()); res.setType(tt::PointerType::get(getNewType(tensorType, layout), @@ -319,6 +368,7 @@ struct CoalescePass assert(succeeded(verify(newOp)) && "Operation verification failed"); } +public: void runOnOperation() override { // Run axis info analysis ModuleOp moduleOp = getOperation(); @@ -340,6 +390,13 @@ struct CoalescePass if (!refTensorType || !refTensorType.getEncoding()) return; + // static int n = 0; + // if (tt::isTensorPointerType(ptr.getType())) + // n++; + + // if (n != 2) + // return; + int numWarps = ttg::TritonGPUDialect::getNumWarps(moduleOp); int threadsPerWarp = ttg::TritonGPUDialect::getThreadsPerWarp(moduleOp); setCoalescedEncoding(axisInfoAnalysis, curr, numWarps, threadsPerWarp, From 25466652e10d43686215d50a13176e73c4a4ec39 Mon Sep 17 00:00:00 2001 From: "Tiotto, Ettore" Date: Thu, 17 Oct 2024 20:16:43 +0000 Subject: [PATCH 05/32] Fix pre_commit Signed-off-by: Tiotto, Ettore --- test/TritonIntelGPU/coalesce.mlir | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/TritonIntelGPU/coalesce.mlir b/test/TritonIntelGPU/coalesce.mlir index 8dfe6d7ebe..d9b2de454c 100644 --- a/test/TritonIntelGPU/coalesce.mlir +++ b/test/TritonIntelGPU/coalesce.mlir @@ -326,8 +326,8 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : %59 = tt.fp_to_fp %32 : tensor<64x64xf8E5M2, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> -> tensor<64x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> %60 = tt.fp_to_fp %58 : tensor<64x32xf8E5M2, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<64x32xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> %61 = tt.dot %59, %60, %cst_2, inputPrecision = tf32 : tensor<64x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * tensor<64x32xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<64x32xf32, #mma> - // CHECK-DAG: [[ADVANCE1:%.*]] = tt.advance [[ARG1]], {{.*}} : > - // CHECK-DAG: [[ADVANCE2:%.*]] = tt.advance [[ARG2]], {{.*}} : >> + // CHECK-DAG: [[ADVANCE1:%.*]] = tt.advance [[ARG1]], {{.*}} : > + // CHECK-DAG: [[ADVANCE2:%.*]] = tt.advance [[ARG2]], {{.*}} : >> // CHECK-NEXT: scf.yield [[ADVANCE1]], [[ADVANCE2]] : !tt.ptr>, !tt.ptr>> %84 = tt.advance %arg26, [%c32_i32, %c0_i32] : >> %85 = tt.advance %arg25, [%c0_i32, %c32_i32] : >> From 4d5dc49ff8cde915fa0593e89412f094a3ef95af Mon Sep 17 00:00:00 2001 From: "Tiotto, Ettore" Date: Thu, 17 Oct 2024 20:56:10 +0000 Subject: [PATCH 06/32] Reenable rewrite tensor ptr Signed-off-by: Tiotto, Ettore --- third_party/intel/backend/compiler.py | 2 +- .../lib/TritonIntelGPUTransforms/Coalesce.cpp | 34 ++++++++----------- 2 files changed, 16 insertions(+), 20 deletions(-) diff --git a/third_party/intel/backend/compiler.py b/third_party/intel/backend/compiler.py index a34111cfa6..b1ce1845f6 100644 --- a/third_party/intel/backend/compiler.py +++ b/third_party/intel/backend/compiler.py @@ -235,7 +235,7 @@ def make_ttgir(mod, metadata, opt, properties): intel.passes.ttgpuir.add_accelerate_matmul(pm) intel.passes.ttgpuir.add_remove_layout_conversions(pm) intel.passes.ttgpuir.add_materialize_block_pointer(pm) - # intel.passes.ttgpuir.add_rewrite_tensor_pointer(pm) + intel.passes.ttgpuir.add_rewrite_tensor_pointer(pm) intel.passes.ttgpuir.add_pipeline(pm, opt.num_stages, False) intel.passes.ttgpuir.add_coalesce(pm) diff --git a/third_party/intel/lib/TritonIntelGPUTransforms/Coalesce.cpp 
b/third_party/intel/lib/TritonIntelGPUTransforms/Coalesce.cpp index 4a05523325..554bb6b81d 100644 --- a/third_party/intel/lib/TritonIntelGPUTransforms/Coalesce.cpp +++ b/third_party/intel/lib/TritonIntelGPUTransforms/Coalesce.cpp @@ -28,6 +28,13 @@ namespace ttgi = mlir::triton::gpu::intel; namespace { +RankedTensorType getRankedTensorType(Type ptrTy) { + return tt::isTensorPointerType(ptrTy) + ? cast( + cast(ptrTy).getPointeeType()) + : dyn_cast(ptrTy); +} + struct CoalescePass : public ttgi::impl::TritonIntelGPUCoalesceBase { private: @@ -49,12 +56,7 @@ struct CoalescePass SmallVector order = argSort(contiguity); LDBG("order=[" << triton::join(order, ", ") << "]"); - RankedTensorType refTensorType = - tt::isTensorPointerType(ptr.getType()) - ? cast( - cast(ptr.getType()).getPointeeType()) - : cast(ptr.getType()); - + RankedTensorType refTensorType = getRankedTensorType(ptr.getType()); auto matchesShape = [&refTensorType](const Value &val) { auto rttType = dyn_cast(val.getType()); return rttType && rttType.getShape() == refTensorType.getShape(); @@ -197,7 +199,7 @@ struct CoalescePass "Unexpected layout"); auto resType = cast(res.getType()); - auto tensorType = cast(resType.getPointeeType()); + RankedTensorType tensorType = getRankedTensorType(resType); res.setType(tt::PointerType::get(getNewType(tensorType, layout), resType.getAddressSpace())); } @@ -243,7 +245,7 @@ struct CoalescePass "Unexpected layout"); auto resType = cast(res.getType()); - auto tensorType = cast(resType.getPointeeType()); + RankedTensorType tensorType = getRankedTensorType(resType); res.setType(tt::PointerType::get(getNewType(tensorType, layout), resType.getAddressSpace())); } @@ -281,8 +283,10 @@ struct CoalescePass } } - // Change the \p layout of the \p op result(s) and propagate the new result - // type to its users. + // TODO: change the implementation to handle only operation yielding one + // result? + // Change the \p layout of the \p op result(s) and propagate the new + // result type to its users. static void changeAndPropagateLayout(Operation *op, Attribute layout, IRRewriter &rewriter) { assert(op && op->getNumResults() != 0 && @@ -293,10 +297,6 @@ struct CoalescePass if (!tt::isTensorPointerType(res.getType())) continue; - // Problem: if the operation is a for loop we cannot modify the layout - // of all the tensor ptr results, we need to modify only the one used by - // the yield operation. - auto ptrType = cast(res.getType()); auto tensorType = cast(ptrType.getPointeeType()); res.setType(tt::PointerType::get(getNewType(tensorType, layout), @@ -382,11 +382,7 @@ struct CoalescePass if (!ptr) return; - RankedTensorType refTensorType = - tt::isTensorPointerType(ptr.getType()) - ? 
cast( - cast(ptr.getType()).getPointeeType()) - : dyn_cast(ptr.getType()); + RankedTensorType refTensorType = getRankedTensorType(ptr.getType()); if (!refTensorType || !refTensorType.getEncoding()) return; From c3fdbba82dfafbd00fd80ed94ed95ff1e4173dc8 Mon Sep 17 00:00:00 2001 From: "Tiotto, Ettore" Date: Fri, 18 Oct 2024 13:13:49 +0000 Subject: [PATCH 07/32] Fix test_core regression Signed-off-by: Tiotto, Ettore --- third_party/intel/lib/Analysis/AxisInfo.cpp | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/third_party/intel/lib/Analysis/AxisInfo.cpp b/third_party/intel/lib/Analysis/AxisInfo.cpp index 70eff83dc1..11bc600e3c 100644 --- a/third_party/intel/lib/Analysis/AxisInfo.cpp +++ b/third_party/intel/lib/Analysis/AxisInfo.cpp @@ -1010,8 +1010,12 @@ class MakeTensorPtrOpAxisInfoVisitor final getAxisInfo(triton::MakeTensorPtrOp op, ArrayRef *> operands) override { LDBG("MakeTensorPtrOpAxisInfoVisitor: " << *op); - assert(op.getShape().size() == 2 && operands.size() == 7 && - "MakeTensorPtrOp should have 2D shape"); + + // TODO: Extend to higher dimension tensor pointers. + if (op.getShape().size() != 2) + return {}; + + assert(operands.size() == 7 && "MakeTensorPtrOp should have 2D shape"); AxisInfo ptrInfo = operands[0]->getValue(); AxisInfo shapeInfo0 = operands[1]->getValue(); From d9de8e772d894fdbdf4d4e2819de05f1aee9323e Mon Sep 17 00:00:00 2001 From: "Tiotto, Ettore" Date: Fri, 18 Oct 2024 16:25:35 +0000 Subject: [PATCH 08/32] Fix tutorial assertion Signed-off-by: Tiotto, Ettore --- .../lib/TritonIntelGPUTransforms/Coalesce.cpp | 69 ++++++++++++------- 1 file changed, 43 insertions(+), 26 deletions(-) diff --git a/third_party/intel/lib/TritonIntelGPUTransforms/Coalesce.cpp b/third_party/intel/lib/TritonIntelGPUTransforms/Coalesce.cpp index 554bb6b81d..fd86573bda 100644 --- a/third_party/intel/lib/TritonIntelGPUTransforms/Coalesce.cpp +++ b/third_party/intel/lib/TritonIntelGPUTransforms/Coalesce.cpp @@ -173,8 +173,15 @@ struct CoalescePass static void propagateLayout(BlockArgument arg, Attribute layout, IRRewriter &rewriter) { llvm::errs() << "arg: " << arg << "\n"; - for (Operation *user : arg.getUsers()) { - llvm::errs() << "user: " << *user << "\n\n"; + + auto users = arg.getUsers(); + if (users.empty()) { + llvm::errs() << "arg has no users\n"; + return; + } + + for (Operation *user : users) { + llvm::errs() << "arg's user: " << *user << "\n\n"; if (filterUser(user)) { llvm::errs() << "SKIP\n"; continue; @@ -218,9 +225,15 @@ struct CoalescePass assert(root && root->getNumResults() != 0 && "Expecting an operation yielding a result"); - // llvm::errs() << "root: " << *root << "\n\n"; - for (Operation *user : root->getUsers()) { - llvm::errs() << "user: " << *user << "\n\n"; + llvm::errs() << "root: " << *root << "\n"; + auto users = root->getUsers(); + if (users.empty()) { + llvm::errs() << "root has no users\n"; + return; + } + + for (Operation *user : users) { + llvm::errs() << "root's user: " << *user << "\n\n"; if (filterUser(user)) { llvm::errs() << "SKIP\n"; continue; @@ -262,7 +275,6 @@ struct CoalescePass if (res == loopArg && tt::isTensorPointerType(res.getType())) { llvm::errs() << "arg: " << arg << "\n"; llvm::errs() << "loopArg: " << loopArg << "\n"; - llvm::errs() << "arg type: " << arg.getType() << "\n"; // Modify the layout of the loop init argument... 
tt::PointerType ptrType = cast(arg.getType()); @@ -309,7 +321,7 @@ struct CoalescePass } void coalesceOp(Attribute encoding, Operation *op) { - llvm::errs() << "Coalescing op: " << *op << "\n"; + LDBG("Coalescing op: " << *op); OpBuilder builder(op); IRRewriter rewriter(builder); @@ -362,9 +374,11 @@ struct CoalescePass } op->getResult(i).replaceAllUsesWith(newResult); } + + LDBG("Old op: " << *op); + LDBG("newOp: " << *newOp); op->erase(); - llvm::errs() << "newOp: " << *newOp << "\n"; assert(succeeded(verify(newOp)) && "Operation verification failed"); } @@ -399,12 +413,15 @@ struct CoalescePass layoutMap); }); - llvm::errs() << "layoutMap:\n"; - for (auto [op, encoding] : layoutMap) { - llvm::errs() << "op: " << *op << "\n"; - llvm::errs() << "encoding: " << encoding << "\n"; - } - llvm::errs() << "\n"; + LLVM_DEBUG({ + DBGS() << "layoutMap:" + << "\n"; + for (auto [op, encoding] : layoutMap) { + DBGS() << "op: " << *op << "\n"; + DBGS() << "encoding: " << encoding << "\n"; + } + llvm::errs() << "\n\n"; + }); // For each memory op that has a layout L1: // 1. Create a coalesced memory layout L2 of the pointer operands @@ -415,22 +432,22 @@ struct CoalescePass // 5. Replace all the uses of the original memory op by the new one for (auto [op, layout] : layoutMap) { coalesceOp(layout, op); - if (failed(verify(moduleOp))) { - for (Operation &op1 : moduleOp.getOps()) { - if (isa(op1)) { - for (Operation &op2 : cast(op1).getOps()) { - if (failed(verify(&op2))) { - llvm::errs() << "op2: " << op2 << "\n"; - llvm::errs() << "Operation verification failed.\n"; - } + } + + if (failed(verify(moduleOp))) { + llvm::errs() << "Module verification failed.\n"; + llvm::errs() << "mod: " << moduleOp << "\n"; + for (Operation &op1 : moduleOp.getOps()) { + if (isa(op1)) { + for (Operation &op2 : cast(op1).getOps()) { + if (failed(verify(&op2))) { + llvm::errs() << "op2: " << op2 << "\n"; + llvm::errs() << "Operation verification failed.\n"; + assert(false); } } } - llvm::errs() << "Module verification failed.\n"; - llvm::errs() << "mod: " << moduleOp << "\n"; - assert(false); } - llvm::errs() << "Module verified.\n"; } } }; From 949256e42c6297129edb542ba74c3c87f49401a7 Mon Sep 17 00:00:00 2001 From: "Tiotto, Ettore" Date: Fri, 18 Oct 2024 19:51:47 +0000 Subject: [PATCH 09/32] Refactor Signed-off-by: Tiotto, Ettore --- .../lib/TritonIntelGPUTransforms/Coalesce.cpp | 210 ++++++++---------- 1 file changed, 87 insertions(+), 123 deletions(-) diff --git a/third_party/intel/lib/TritonIntelGPUTransforms/Coalesce.cpp b/third_party/intel/lib/TritonIntelGPUTransforms/Coalesce.cpp index fd86573bda..b10ee7b2a2 100644 --- a/third_party/intel/lib/TritonIntelGPUTransforms/Coalesce.cpp +++ b/third_party/intel/lib/TritonIntelGPUTransforms/Coalesce.cpp @@ -2,6 +2,8 @@ #include "intel/include/Dialect/TritonIntelGPU/IR/Dialect.h" #include "intel/include/Dialect/TritonIntelGPU/IR/Utils.h" #include "intel/include/Dialect/TritonIntelGPU/Transforms/Passes.h" +#include "mlir/IR/Operation.h" +#include "mlir/IR/Value.h" #include "mlir/IR/Verifier.h" #include "mlir/Support/LLVM.h" #include "triton/Dialect/Triton/IR/Dialect.h" @@ -11,6 +13,7 @@ #include "triton/Tools/StrUtil.h" #include "llvm/Support/Debug.h" #include "llvm/Support/raw_ostream.h" +#include #define DEBUG_TYPE "tritonintelgpu-coalesce" #define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ") @@ -150,7 +153,7 @@ struct CoalescePass static bool filterUser(Operation *op) { // Yield operations trigger updating the layout of the containing loop - // results, so don't skip them. 
+ // results, don't skip them. if (isa(op)) return false; @@ -168,154 +171,123 @@ struct CoalescePass return false; } - // Propagate the \p root block argument operation output layout along the - // def-use chain. - static void propagateLayout(BlockArgument arg, Attribute layout, - IRRewriter &rewriter) { - llvm::errs() << "arg: " << arg << "\n"; + // Propagate the layout to \p root operation's result to the \p forOp loop + // init argument that uses it, and transitively to the operations in the loop + // body that use that argument. + static void propagate(scf::ForOp forOp, Operation *root, Attribute layout, + IRRewriter &rewriter) { + assert(llvm::any_of(root->getUsers(), + [&](Operation *user) { return user == forOp; }) && + "Expecting the loop to be a user of the root operation"); + + for (BlockArgument arg : forOp.getRegionIterArgs()) { + Value loopArg = forOp.getInitArgs()[arg.getArgNumber() - 1]; + for (OpResult res : root->getResults()) { + if (res != loopArg || !tt::isTensorPointerType(res.getType())) + continue; - auto users = arg.getUsers(); - if (users.empty()) { - llvm::errs() << "arg has no users\n"; - return; + LDBG("loopArg: " << loopArg); + + // Modify the layout of the loop init argument... + tt::PointerType ptrType = cast(arg.getType()); + auto tensorType = cast(ptrType.getPointeeType()); + arg.setType(tt::PointerType::get(getNewType(tensorType, layout), + ptrType.getAddressSpace())); + + // ... and then propagate it to the operations in the loop. + propagateLayout(arg, layout, rewriter); + } } + } + + // Modify the given loop \p forOp and propagate the result of the enclosing + // loop. + static void propagate(scf::ForOp forOp, Attribute layout, + IRRewriter &rewriter) { + Operation *yieldOp = forOp.getBody()->getTerminator(); + + rewriter.modifyOpInPlace(forOp, [&]() { + for (auto [opType, res] : + llvm::zip(yieldOp->getOperandTypes(), forOp.getResults())) { + if (opType == res.getType()) + continue; + + assert(tt::isTensorPointerType(res.getType()) && + tt::isTensorPointerType(opType) && "Expecting blocked pointers"); + assert(cast( + cast(opType).getPointeeType()) + .getEncoding() == layout && + "Unexpected layout"); - for (Operation *user : users) { - llvm::errs() << "arg's user: " << *user << "\n\n"; + auto resType = cast(res.getType()); + RankedTensorType tensorType = getRankedTensorType(resType); + res.setType(tt::PointerType::get(getNewType(tensorType, layout), + resType.getAddressSpace())); + } + }); + + propagateLayout(forOp, layout, rewriter); + } + + static void propagateLayout(BlockArgument arg, Attribute layout, + IRRewriter &rewriter) { + LDBG("arg: " << arg); + for (Operation *user : arg.getUsers()) { + LDBG("arg's user: " << *user << "\n"); if (filterUser(user)) { - llvm::errs() << "SKIP\n"; continue; } - if (auto yieldOp = dyn_cast(user)) { - // Modify and propagate the result of the enclosing loop. 
auto forOp = yieldOp->getParentOfType(); - - rewriter.modifyOpInPlace(forOp, [&]() { - for (auto [opType, res] : - llvm::zip(yieldOp->getOperandTypes(), forOp.getResults())) { - if (opType == res.getType()) - continue; - - assert(tt::isTensorPointerType(res.getType()) && - tt::isTensorPointerType(opType) && - "Expecting blocked pointers"); - assert(cast( - cast(opType).getPointeeType()) - .getEncoding() == layout && - "Unexpected layout"); - - auto resType = cast(res.getType()); - RankedTensorType tensorType = getRankedTensorType(resType); - res.setType(tt::PointerType::get(getNewType(tensorType, layout), - resType.getAddressSpace())); - } - }); - - propagateLayout(forOp, layout, rewriter); + propagate(forOp, layout, rewriter); continue; } - changeAndPropagateLayout(user, layout, rewriter); } } static void propagateLayout(Operation *root, Attribute layout, IRRewriter &rewriter) { - assert(root && root->getNumResults() != 0 && + assert(root->getNumResults() != 0 && "Expecting an operation yielding a result"); - llvm::errs() << "root: " << *root << "\n"; - auto users = root->getUsers(); - if (users.empty()) { - llvm::errs() << "root has no users\n"; - return; - } - - for (Operation *user : users) { - llvm::errs() << "root's user: " << *user << "\n\n"; + LDBG("root: " << *root); + for (Operation *user : root->getUsers()) { + LDBG("root's user: " << *user << "\n"); if (filterUser(user)) { - llvm::errs() << "SKIP\n"; continue; } - - if (auto yieldOp = dyn_cast(user)) { - // Modify and propagate the result of the enclosing loop. - auto forOp = yieldOp->getParentOfType(); - - rewriter.modifyOpInPlace(forOp, [&]() { - for (auto [opType, res] : - llvm::zip(yieldOp->getOperandTypes(), forOp.getResults())) { - if (opType == res.getType()) - continue; - - assert(tt::isTensorPointerType(res.getType()) && - tt::isTensorPointerType(opType) && - "Expecting blocked pointers"); - assert(cast( - cast(opType).getPointeeType()) - .getEncoding() == layout && - "Unexpected layout"); - - auto resType = cast(res.getType()); - RankedTensorType tensorType = getRankedTensorType(resType); - res.setType(tt::PointerType::get(getNewType(tensorType, layout), - resType.getAddressSpace())); - } - }); - - propagateLayout(forOp, layout, rewriter); + if (auto forOp = dyn_cast(user)) { + propagate(forOp, root, layout, rewriter); continue; } - - if (auto forOp = dyn_cast(user)) { - for (BlockArgument arg : forOp.getRegionIterArgs()) { - Value loopArg = forOp.getInitArgs()[arg.getArgNumber() - 1]; - for (OpResult res : root->getResults()) { - if (res == loopArg && tt::isTensorPointerType(res.getType())) { - llvm::errs() << "arg: " << arg << "\n"; - llvm::errs() << "loopArg: " << loopArg << "\n"; - - // Modify the layout of the loop init argument... - tt::PointerType ptrType = cast(arg.getType()); - auto tensorType = - cast(ptrType.getPointeeType()); - arg.setType(tt::PointerType::get(getNewType(tensorType, layout), - ptrType.getAddressSpace())); - - // ... and then propagate it to the operations in the loop. - propagateLayout(arg, layout, rewriter); - } - } - } + if (auto yieldOp = dyn_cast(user)) { + auto forOp = yieldOp->getParentOfType(); + propagate(forOp, layout, rewriter); continue; } - changeAndPropagateLayout(user, layout, rewriter); } } - // TODO: change the implementation to handle only operation yielding one - // result? - // Change the \p layout of the \p op result(s) and propagate the new - // result type to its users. + // Change the \p layout of the \p op result and propagate the new result type + // to its users. 
static void changeAndPropagateLayout(Operation *op, Attribute layout, IRRewriter &rewriter) { - assert(op && op->getNumResults() != 0 && + assert(op && op->getNumResults() == 1 && "Expecting operation yielding a result"); rewriter.modifyOpInPlace(op, [&]() { - for (Value res : op->getResults()) { - if (!tt::isTensorPointerType(res.getType())) - continue; - - auto ptrType = cast(res.getType()); - auto tensorType = cast(ptrType.getPointeeType()); - res.setType(tt::PointerType::get(getNewType(tensorType, layout), - ptrType.getAddressSpace())); - } + Value res = op->getOpResult(0); + assert(tt::isTensorPointerType(res.getType()) && + "Expecting a block pointer"); + + auto ptrType = cast(res.getType()); + auto tensorType = cast(ptrType.getPointeeType()); + res.setType(tt::PointerType::get(getNewType(tensorType, layout), + ptrType.getAddressSpace())); }); - llvm::errs() << "Coalesced op: " << *op << "\n"; + LDBG("Coalesced op: " << *op); propagateLayout(op, layout, rewriter); } @@ -400,13 +372,6 @@ struct CoalescePass if (!refTensorType || !refTensorType.getEncoding()) return; - // static int n = 0; - // if (tt::isTensorPointerType(ptr.getType())) - // n++; - - // if (n != 2) - // return; - int numWarps = ttg::TritonGPUDialect::getNumWarps(moduleOp); int threadsPerWarp = ttg::TritonGPUDialect::getThreadsPerWarp(moduleOp); setCoalescedEncoding(axisInfoAnalysis, curr, numWarps, threadsPerWarp, @@ -414,8 +379,7 @@ struct CoalescePass }); LLVM_DEBUG({ - DBGS() << "layoutMap:" - << "\n"; + DBGS() << "layoutMap:" << "\n"; for (auto [op, encoding] : layoutMap) { DBGS() << "op: " << *op << "\n"; DBGS() << "encoding: " << encoding << "\n"; From 754ec703cf73cfc097e591359e9aa0617ec6fac1 Mon Sep 17 00:00:00 2001 From: "Tiotto, Ettore" Date: Fri, 18 Oct 2024 20:24:02 +0000 Subject: [PATCH 10/32] Cleanup Signed-off-by: Tiotto, Ettore --- .../lib/TritonIntelGPUTransforms/Coalesce.cpp | 166 +++++++++--------- 1 file changed, 79 insertions(+), 87 deletions(-) diff --git a/third_party/intel/lib/TritonIntelGPUTransforms/Coalesce.cpp b/third_party/intel/lib/TritonIntelGPUTransforms/Coalesce.cpp index b10ee7b2a2..a69dfec2d0 100644 --- a/third_party/intel/lib/TritonIntelGPUTransforms/Coalesce.cpp +++ b/third_party/intel/lib/TritonIntelGPUTransforms/Coalesce.cpp @@ -171,11 +171,77 @@ struct CoalescePass return false; } - // Propagate the layout to \p root operation's result to the \p forOp loop + // Change the \p layout of the \p op result and propagate the new result type + // to its users. + void changeAndPropagateLayout(Operation *op, Attribute layout, + IRRewriter &rewriter) const { + assert(op && op->getNumResults() == 1 && + "Expecting operation yielding a result"); + + rewriter.modifyOpInPlace(op, [&]() { + Value res = op->getOpResult(0); + assert(tt::isTensorPointerType(res.getType()) && + "Expecting a block pointer"); + + auto ptrType = cast(res.getType()); + auto tensorType = cast(ptrType.getPointeeType()); + res.setType(tt::PointerType::get(getNewType(tensorType, layout), + ptrType.getAddressSpace())); + }); + LDBG("Coalesced op: " << *op); + + propagateLayout(op, layout, rewriter); + } + + // Propagate the layout of the \p root operation's result to its users. 
+ void propagateLayout(Operation *root, Attribute layout, + IRRewriter &rewriter) const { + assert(root->getNumResults() != 0 && + "Expecting an operation yielding a result"); + + LDBG("root: " << *root); + for (Operation *user : root->getUsers()) { + if (filterUser(user)) + continue; + + LDBG("root's user: " << *user << "\n"); + if (auto forOp = dyn_cast(user)) { + propagateLayoutToArgsAndBody(forOp, root, layout, rewriter); + continue; + } + if (auto yieldOp = dyn_cast(user)) { + auto forOp = yieldOp->getParentOfType(); + propagateLayoutToLoopResults(forOp, layout, rewriter); + continue; + } + changeAndPropagateLayout(user, layout, rewriter); + } + } + + // Propagate the layout of the \p arg block argument to its users. + void propagateLayout(BlockArgument arg, Attribute layout, + IRRewriter &rewriter) const { + LDBG("arg: " << arg); + for (Operation *user : arg.getUsers()) { + if (filterUser(user)) + continue; + + LDBG("arg's user: " << *user << "\n"); + if (auto yieldOp = dyn_cast(user)) { + auto forOp = yieldOp->getParentOfType(); + propagateLayoutToLoopResults(forOp, layout, rewriter); + continue; + } + changeAndPropagateLayout(user, layout, rewriter); + } + } + + // Propagate the layout of the \p root operation's result to the \p forOp loop // init argument that uses it, and transitively to the operations in the loop // body that use that argument. - static void propagate(scf::ForOp forOp, Operation *root, Attribute layout, - IRRewriter &rewriter) { + void propagateLayoutToArgsAndBody(scf::ForOp forOp, Operation *root, + Attribute layout, + IRRewriter &rewriter) const { assert(llvm::any_of(root->getUsers(), [&](Operation *user) { return user == forOp; }) && "Expecting the loop to be a user of the root operation"); @@ -202,8 +268,8 @@ struct CoalescePass // Modify the given loop \p forOp and propagate the result of the enclosing // loop. - static void propagate(scf::ForOp forOp, Attribute layout, - IRRewriter &rewriter) { + void propagateLayoutToLoopResults(scf::ForOp forOp, Attribute layout, + IRRewriter &rewriter) const { Operation *yieldOp = forOp.getBody()->getTerminator(); rewriter.modifyOpInPlace(forOp, [&]() { @@ -229,69 +295,6 @@ struct CoalescePass propagateLayout(forOp, layout, rewriter); } - static void propagateLayout(BlockArgument arg, Attribute layout, - IRRewriter &rewriter) { - LDBG("arg: " << arg); - for (Operation *user : arg.getUsers()) { - LDBG("arg's user: " << *user << "\n"); - if (filterUser(user)) { - continue; - } - if (auto yieldOp = dyn_cast(user)) { - auto forOp = yieldOp->getParentOfType(); - propagate(forOp, layout, rewriter); - continue; - } - changeAndPropagateLayout(user, layout, rewriter); - } - } - - static void propagateLayout(Operation *root, Attribute layout, - IRRewriter &rewriter) { - assert(root->getNumResults() != 0 && - "Expecting an operation yielding a result"); - - LDBG("root: " << *root); - for (Operation *user : root->getUsers()) { - LDBG("root's user: " << *user << "\n"); - if (filterUser(user)) { - continue; - } - if (auto forOp = dyn_cast(user)) { - propagate(forOp, root, layout, rewriter); - continue; - } - if (auto yieldOp = dyn_cast(user)) { - auto forOp = yieldOp->getParentOfType(); - propagate(forOp, layout, rewriter); - continue; - } - changeAndPropagateLayout(user, layout, rewriter); - } - } - - // Change the \p layout of the \p op result and propagate the new result type - // to its users. 
- static void changeAndPropagateLayout(Operation *op, Attribute layout, - IRRewriter &rewriter) { - assert(op && op->getNumResults() == 1 && - "Expecting operation yielding a result"); - - rewriter.modifyOpInPlace(op, [&]() { - Value res = op->getOpResult(0); - assert(tt::isTensorPointerType(res.getType()) && - "Expecting a block pointer"); - - auto ptrType = cast(res.getType()); - auto tensorType = cast(ptrType.getPointeeType()); - res.setType(tt::PointerType::get(getNewType(tensorType, layout), - ptrType.getAddressSpace())); - }); - LDBG("Coalesced op: " << *op); - - propagateLayout(op, layout, rewriter); - } - void coalesceOp(Attribute encoding, Operation *op) { LDBG("Coalescing op: " << *op); @@ -316,8 +319,7 @@ struct CoalescePass "Expecting operand to have blocked pointer type"); auto defOp = findDefiningMakeTensorPtrOp(operand); assert(defOp && "Expected a make_tensor_ptr operation"); - - llvm::errs() << "Found make_tensor_ptr definition: " << *defOp << "\n"; + LDBG("Found make_tensor_ptr definition: " << *defOp); changeAndPropagateLayout(*defOp, encoding, rewriter); newArgs.push_back(operand); } @@ -326,8 +328,7 @@ struct CoalescePass // Convert output types SmallVector newTypes; for (auto t : op->getResultTypes()) { - bool isAsync = isa(op); - assert(!isAsync && + assert(!isa(op) && "AsyncCopyGlobalToLocalOp not supported for Intel GPU"); newTypes.push_back(getNewType(cast(t), encoding)); } @@ -379,7 +380,8 @@ struct CoalescePass }); LLVM_DEBUG({ - DBGS() << "layoutMap:" << "\n"; + DBGS() << "layoutMap:" + << "\n"; for (auto [op, encoding] : layoutMap) { DBGS() << "op: " << *op << "\n"; DBGS() << "encoding: " << encoding << "\n"; @@ -398,20 +400,10 @@ struct CoalescePass coalesceOp(layout, op); } - if (failed(verify(moduleOp))) { - llvm::errs() << "Module verification failed.\n"; - llvm::errs() << "mod: " << moduleOp << "\n"; - for (Operation &op1 : moduleOp.getOps()) { - if (isa(op1)) { - for (Operation &op2 : cast(op1).getOps()) { - if (failed(verify(&op2))) { - llvm::errs() << "op2: " << op2 << "\n"; - llvm::errs() << "Operation verification failed.\n"; - assert(false); - } - } - } - } + // Verify the module's functions after the transformation. 
+ for (auto op : moduleOp.getOps()) { + for (Operation &op1 : op.getOps()) + assert(succeeded(verify(&op1))); } } }; From 469407b94240af90d49c8bf7a0b19652a2cb1e02 Mon Sep 17 00:00:00 2001 From: "Tiotto, Ettore" Date: Fri, 18 Oct 2024 20:42:20 +0000 Subject: [PATCH 11/32] Cleanup Signed-off-by: Tiotto, Ettore --- lib/Analysis/CMakeLists.txt | 1 - lib/Analysis/intel/TestAxisInfo.cpp | 47 ----------------------------- test/lib/Analysis/CMakeLists.txt | 1 + 3 files changed, 1 insertion(+), 48 deletions(-) delete mode 100644 lib/Analysis/intel/TestAxisInfo.cpp diff --git a/lib/Analysis/CMakeLists.txt b/lib/Analysis/CMakeLists.txt index 60a7dd0b43..a84f0649b6 100644 --- a/lib/Analysis/CMakeLists.txt +++ b/lib/Analysis/CMakeLists.txt @@ -1,5 +1,4 @@ add_triton_library(TritonAnalysis - intel/TestAxisInfo.cpp AxisInfo.cpp Allocation.cpp Membar.cpp diff --git a/lib/Analysis/intel/TestAxisInfo.cpp b/lib/Analysis/intel/TestAxisInfo.cpp deleted file mode 100644 index 5317a492fc..0000000000 --- a/lib/Analysis/intel/TestAxisInfo.cpp +++ /dev/null @@ -1,47 +0,0 @@ -#include "intel/include/Analysis/AxisInfo.h" -#include "mlir/Pass/Pass.h" - -using namespace mlir; -using namespace mlir::triton::intel; - -namespace { - -struct TestAxisInfoPass - : public PassWrapper> { - - MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestAxisInfoPass); - - StringRef getArgument() const final { return "test-print-axis-info"; } - StringRef getDescription() const final { - return "print the result of the alignment analysis pass"; - } - - void runOnOperation() override { - Operation *operation = getOperation(); - ModuleOp moduleOp = cast(operation); - ModuleAxisInfoAnalysis moduleAxisInfoAnalysis(moduleOp); - moduleOp.walk([&](triton::FuncOp funcOp) { - auto &os = llvm::errs(); - auto opName = SymbolTable::getSymbolName(funcOp).getValue().str(); - os << "@" << opName << "\n"; - funcOp.walk([&](Operation *op) { - if (op->getNumResults() < 1) - return; - for (Value result : op->getResults()) { - result.print(os); - os << " => "; - auto *axisInfo = moduleAxisInfoAnalysis.getAxisInfo(result); - if (axisInfo) - axisInfo->print(os); - os << "\n"; - } - }); - }); - } -}; - -} // namespace - -namespace mlir::test::intel { -void registerTestAxisInfoPass() { PassRegistration(); } -} // namespace mlir::test::intel diff --git a/test/lib/Analysis/CMakeLists.txt b/test/lib/Analysis/CMakeLists.txt index da7bc3f78a..75f785ce24 100644 --- a/test/lib/Analysis/CMakeLists.txt +++ b/test/lib/Analysis/CMakeLists.txt @@ -1,4 +1,5 @@ add_mlir_library(TritonTestAnalysis + intel/TestAxisInfo.cpp TestAlias.cpp TestAxisInfo.cpp TestAllocation.cpp From 9f4f98d3be6ba290d04b68b39ee0e30595554b90 Mon Sep 17 00:00:00 2001 From: "Tiotto, Ettore" Date: Mon, 21 Oct 2024 21:23:42 +0000 Subject: [PATCH 12/32] Extend axis info analysis to more block ptrs Signed-off-by: Tiotto, Ettore --- test/Analysis/intel/test-axis-info.mlir | 6 +- third_party/intel/lib/Analysis/AxisInfo.cpp | 59 ++++++++----------- .../lib/TritonIntelGPUTransforms/Coalesce.cpp | 4 +- 3 files changed, 31 insertions(+), 38 deletions(-) diff --git a/test/Analysis/intel/test-axis-info.mlir b/test/Analysis/intel/test-axis-info.mlir index 1a3805f018..39dcd0bd3e 100644 --- a/test/Analysis/intel/test-axis-info.mlir +++ b/test/Analysis/intel/test-axis-info.mlir @@ -885,9 +885,11 @@ tt.func public @make_tensor_ptr(%arg0: !tt.ptr, %arg1: !tt.ptr {tt. 
%c1_i64 = arith.constant 1 : i64 %c32_i64 = arith.constant 32 : i64 %c128_i64 = arith.constant 128 : i64 - // CHECK: %0 = tt.make_tensor_ptr %arg0, {{.*}} => contiguity = [128, 32], divisibility = [1, 1], constancy = [1, 1], constant_value = + // CHECK: tt.make_tensor_ptr %arg0, {{.*}} => contiguity = [128, 32], divisibility = [1, 1], constancy = [1, 1], constant_value = %0 = tt.make_tensor_ptr %arg0, [%c128_i64, %c32_i64], [%c1_i64, %c1_i64], [%c0_i32, %c0_i32] {order = array} : !tt.ptr> - // CHECK: %1 = tt.make_tensor_ptr %arg1, {{.*}} => contiguity = [32, 1], divisibility = [16, 1], constancy = [1, 1], constant_value = + // CHECK: tt.make_tensor_ptr %arg1, {{.*}} => contiguity = [64, 1], divisibility = [16, 1], constancy = [1, 1], constant_value = %1 = tt.make_tensor_ptr %arg1, [%c32_i64, %c32_i64], [%c1_i64, %arg2], [%c0_i32, %c0_i32] {order = array} : > + // CHECK: tt.make_tensor_ptr %arg1, {{.*}} => contiguity = [32, 64], divisibility = [1, 1], constancy = [1, 1], constant_value = + %2 = tt.make_tensor_ptr %arg1, [%arg2, %c128_i64], [%c1_i64, %c1_i64], [%c0_i32, %c0_i32] {order = array} : > tt.return } diff --git a/third_party/intel/lib/Analysis/AxisInfo.cpp b/third_party/intel/lib/Analysis/AxisInfo.cpp index 11bc600e3c..8da0c3284b 100644 --- a/third_party/intel/lib/Analysis/AxisInfo.cpp +++ b/third_party/intel/lib/Analysis/AxisInfo.cpp @@ -4,6 +4,7 @@ #include "llvm/Support/raw_ostream.h" #include "intel/include/Analysis/AxisInfo.h" +#include "mlir/IR/BuiltinTypes.h" #include "triton/Dialect/Triton/IR/Dialect.h" #define DEBUG_TYPE "intel-axis-info" @@ -1011,45 +1012,35 @@ class MakeTensorPtrOpAxisInfoVisitor final ArrayRef *> operands) override { LDBG("MakeTensorPtrOpAxisInfoVisitor: " << *op); - // TODO: Extend to higher dimension tensor pointers. - if (op.getShape().size() != 2) - return {}; + auto ptrTy = cast(op.getResult().getType()); + auto tensorType = cast(ptrTy.getPointeeType()); + ArrayRef blkShape = tensorType.getShape(); + unsigned rank = op.getShape().size(); - assert(operands.size() == 7 && "MakeTensorPtrOp should have 2D shape"); + // TODO: Support higher rank tensors. + if (rank > 2) + return AxisInfo(); + + SmallVector strideInfo; + for (int i = rank + 1; i <= rank * 2; ++i) + strideInfo.emplace_back(operands[i]->getValue()); AxisInfo ptrInfo = operands[0]->getValue(); - AxisInfo shapeInfo0 = operands[1]->getValue(); - AxisInfo shapeInfo1 = operands[2]->getValue(); - AxisInfo strideInfo0 = operands[3]->getValue(); - AxisInfo strideInfo1 = operands[4]->getValue(); - - std::optional shape0 = shapeInfo0.getConstantValue(); - std::optional shape1 = shapeInfo1.getConstantValue(); - std::optional stride0 = strideInfo0.getConstantValue(); - std::optional stride1 = strideInfo1.getConstantValue(); - - AxisInfo::DimVectorT contiguity{ - shape0.has_value() && (stride0 == 1) ? shape0.value() : 1, - shape1.has_value() && (stride1 == 1) ? 
shape1.value() : 1}; - - int64_t ptrDivisibility = ptrInfo.getDivisibility()[0]; - int64_t strideDivisibility0 = strideInfo0.getDivisibility()[0]; - int64_t strideDivisibility1 = strideInfo1.getDivisibility()[0]; - - LDBG("ptrDivisibility: " << ptrDivisibility); - LDBG("strideDivisibility0: " << strideDivisibility0); - LDBG("strideDivisibility1: " << strideDivisibility1); - - AxisInfo::DimVectorT divisibility{1, 1}; - if (ptrDivisibility > 1) { - if (contiguity[0] > 1) - divisibility[0] = std::min(ptrDivisibility, strideDivisibility1); - if (contiguity[1] > 1) - divisibility[1] = std::min(ptrDivisibility, strideDivisibility0); + int64_t ptrDivisibility = ptrInfo.getDivisibility(0); + LDBG("ptrDivisibility: "); + + AxisInfo::DimVectorT contiguity, constancy, divisibility; + for (int dim = 0; dim < rank; ++dim) { + contiguity.push_back( + strideInfo[dim].getConstantValue() == 1 ? blkShape[dim] : 1); + divisibility.push_back( + contiguity[dim] > 1 + ? std::min(ptrDivisibility, + strideInfo[dim == 0 ? 1 : 0].getDivisibility()[0]) + : 1); + constancy.push_back(1); } - AxisInfo::DimVectorT constancy{1, 1}; - return AxisInfo(contiguity, divisibility, constancy); } }; diff --git a/third_party/intel/lib/TritonIntelGPUTransforms/Coalesce.cpp b/third_party/intel/lib/TritonIntelGPUTransforms/Coalesce.cpp index a69dfec2d0..854c76d9e2 100644 --- a/third_party/intel/lib/TritonIntelGPUTransforms/Coalesce.cpp +++ b/third_party/intel/lib/TritonIntelGPUTransforms/Coalesce.cpp @@ -380,11 +380,11 @@ struct CoalescePass }); LLVM_DEBUG({ - DBGS() << "layoutMap:" + DBGS() << "\nlayoutMap:" << "\n"; for (auto [op, encoding] : layoutMap) { DBGS() << "op: " << *op << "\n"; - DBGS() << "encoding: " << encoding << "\n"; + DBGS() << "encoding: " << encoding << "\n\n"; } llvm::errs() << "\n\n"; }); From bb9b4c3a9a9f1670527f0416b081988bed10f7c7 Mon Sep 17 00:00:00 2001 From: "Tiotto, Ettore" Date: Tue, 22 Oct 2024 16:11:06 +0000 Subject: [PATCH 13/32] Address code review comments Signed-off-by: Tiotto, Ettore --- third_party/intel/backend/compiler.py | 2 +- third_party/intel/include/Analysis/AxisInfo.h | 1 - .../include/Dialect/TritonIntelGPU/IR/Utils.h | 1 - third_party/intel/lib/Analysis/AxisInfo.cpp | 6 ++--- .../lib/TritonIntelGPUTransforms/Coalesce.cpp | 27 ++++++++----------- 5 files changed, 14 insertions(+), 23 deletions(-) diff --git a/third_party/intel/backend/compiler.py b/third_party/intel/backend/compiler.py index 86948112b9..7119377d11 100644 --- a/third_party/intel/backend/compiler.py +++ b/third_party/intel/backend/compiler.py @@ -235,7 +235,7 @@ def make_ttgir(mod, metadata, opt, properties): intel.passes.ttgpuir.add_accelerate_matmul(pm) intel.passes.ttgpuir.add_remove_layout_conversions(pm) intel.passes.ttgpuir.add_materialize_block_pointer(pm) - intel.passes.ttgpuir.add_rewrite_tensor_pointer(pm) + # intel.passes.ttgpuir.add_rewrite_tensor_pointer(pm) intel.passes.ttgpuir.add_pipeline(pm, opt.num_stages, False) intel.passes.ttgpuir.add_coalesce(pm) diff --git a/third_party/intel/include/Analysis/AxisInfo.h b/third_party/intel/include/Analysis/AxisInfo.h index 3016e02cad..1fbaba2e0c 100644 --- a/third_party/intel/include/Analysis/AxisInfo.h +++ b/third_party/intel/include/Analysis/AxisInfo.h @@ -11,7 +11,6 @@ #include "triton/Dialect/TritonGPU/IR/Dialect.h" #include -#include namespace mlir::triton::intel { diff --git a/third_party/intel/include/Dialect/TritonIntelGPU/IR/Utils.h b/third_party/intel/include/Dialect/TritonIntelGPU/IR/Utils.h index 7950fe1377..6357d4a8c2 100644 --- 
a/third_party/intel/include/Dialect/TritonIntelGPU/IR/Utils.h +++ b/third_party/intel/include/Dialect/TritonIntelGPU/IR/Utils.h @@ -46,7 +46,6 @@ inline unsigned getNumElementsPerThread( inline bool applyTransposedReduction() { return tools::getBoolEnv("TRITON_INTEL_REDUCE_TRANSPOSE"); } - } // namespace mlir::triton::gpu::intel #endif // TRITON_DIALECT_TRITON_INTEL_GPU_IR_UTILS_H diff --git a/third_party/intel/lib/Analysis/AxisInfo.cpp b/third_party/intel/lib/Analysis/AxisInfo.cpp index aa0805484a..f6b6b4d9c7 100644 --- a/third_party/intel/lib/Analysis/AxisInfo.cpp +++ b/third_party/intel/lib/Analysis/AxisInfo.cpp @@ -1,12 +1,10 @@ +#include "intel/include/Analysis/AxisInfo.h" #include "mlir/Analysis/DataFlowFramework.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "triton/Dialect/Triton/IR/Dialect.h" #include "llvm/Support/Debug.h" #include "llvm/Support/raw_ostream.h" -#include "intel/include/Analysis/AxisInfo.h" -#include "mlir/IR/BuiltinTypes.h" -#include "triton/Dialect/Triton/IR/Dialect.h" - #define DEBUG_TYPE "intel-axis-info" #define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ") #define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n") diff --git a/third_party/intel/lib/TritonIntelGPUTransforms/Coalesce.cpp b/third_party/intel/lib/TritonIntelGPUTransforms/Coalesce.cpp index 854c76d9e2..9a475a04bc 100644 --- a/third_party/intel/lib/TritonIntelGPUTransforms/Coalesce.cpp +++ b/third_party/intel/lib/TritonIntelGPUTransforms/Coalesce.cpp @@ -1,19 +1,18 @@ #include "intel/include/Analysis/AxisInfo.h" -#include "intel/include/Dialect/TritonIntelGPU/IR/Dialect.h" +// #include "intel/include/Dialect/TritonIntelGPU/IR/Dialect.h" #include "intel/include/Dialect/TritonIntelGPU/IR/Utils.h" #include "intel/include/Dialect/TritonIntelGPU/Transforms/Passes.h" #include "mlir/IR/Operation.h" #include "mlir/IR/Value.h" #include "mlir/IR/Verifier.h" #include "mlir/Support/LLVM.h" -#include "triton/Dialect/Triton/IR/Dialect.h" -#include "triton/Dialect/Triton/IR/Types.h" +// #include "triton/Dialect/Triton/IR/Dialect.h" +// #include "triton/Dialect/Triton/IR/Types.h" #include "triton/Dialect/Triton/IR/Utility.h" #include "triton/Dialect/TritonGPU/Transforms/Utility.h" #include "triton/Tools/StrUtil.h" #include "llvm/Support/Debug.h" #include "llvm/Support/raw_ostream.h" -#include #define DEBUG_TYPE "tritonintelgpu-coalesce" #define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ") @@ -104,8 +103,8 @@ struct CoalescePass perThread = std::min(perThread, std::max(numElems / numThreads, 1)); LDBG("perThread: " << perThread); - if (perThread <= 1) - return; + // if (perThread <= 1) + // return; if (!dyn_cast(op)) { // For ops that can result in a global memory write, we should enforce @@ -299,7 +298,6 @@ struct CoalescePass LDBG("Coalescing op: " << *op); OpBuilder builder(op); - IRRewriter rewriter(builder); // Convert operands // Note: for load/store with a blocked pointers argument we cannot change @@ -312,7 +310,7 @@ struct CoalescePass if (tensorType && !isa(tensorType.getEncoding())) { RankedTensorType newType = getNewType(tensorType, encoding); - newArgs.push_back(rewriter.create( + newArgs.push_back(builder.create( op->getLoc(), newType, operand)); } else { assert(isa(operand.getType()) && @@ -320,6 +318,7 @@ struct CoalescePass auto defOp = findDefiningMakeTensorPtrOp(operand); assert(defOp && "Expected a make_tensor_ptr operation"); LDBG("Found make_tensor_ptr definition: " << *defOp); + IRRewriter rewriter(builder); changeAndPropagateLayout(*defOp, encoding, rewriter); newArgs.push_back(operand); } 
@@ -335,14 +334,14 @@ struct CoalescePass // Construct new op with the new encoding. Operation *newOp = - rewriter.create(op->getLoc(), op->getName().getIdentifier(), newArgs, - newTypes, op->getAttrs()); + builder.create(op->getLoc(), op->getName().getIdentifier(), newArgs, + newTypes, op->getAttrs()); // Cast the results back to the original layout. for (size_t i = 0; i < op->getNumResults(); i++) { Value newResult = newOp->getResult(i); if (newTypes[i] != op->getResultTypes()[i]) { - newResult = rewriter.create( + newResult = builder.create( op->getLoc(), op->getResult(i).getType(), newResult); } op->getResult(i).replaceAllUsesWith(newResult); @@ -400,11 +399,7 @@ struct CoalescePass coalesceOp(layout, op); } - // Verify the module's functions after the transformation. - for (auto op : moduleOp.getOps()) { - for (Operation &op1 : op.getOps()) - assert(succeeded(verify(&op1))); - } + assert(succeeded(verify(moduleOp)) && "Module verification failed"); } }; From 8d9a158ea3227355e181afa01ae851e6500ed460 Mon Sep 17 00:00:00 2001 From: "Tiotto, Ettore" Date: Tue, 22 Oct 2024 16:12:53 +0000 Subject: [PATCH 14/32] Remove unrelated change Signed-off-by: Tiotto, Ettore --- third_party/intel/backend/compiler.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/third_party/intel/backend/compiler.py b/third_party/intel/backend/compiler.py index 7119377d11..86948112b9 100644 --- a/third_party/intel/backend/compiler.py +++ b/third_party/intel/backend/compiler.py @@ -235,7 +235,7 @@ def make_ttgir(mod, metadata, opt, properties): intel.passes.ttgpuir.add_accelerate_matmul(pm) intel.passes.ttgpuir.add_remove_layout_conversions(pm) intel.passes.ttgpuir.add_materialize_block_pointer(pm) - # intel.passes.ttgpuir.add_rewrite_tensor_pointer(pm) + intel.passes.ttgpuir.add_rewrite_tensor_pointer(pm) intel.passes.ttgpuir.add_pipeline(pm, opt.num_stages, False) intel.passes.ttgpuir.add_coalesce(pm) From 6529f043d853cd2ca678552efc38333a620c57d2 Mon Sep 17 00:00:00 2001 From: "Tiotto, Ettore" Date: Tue, 22 Oct 2024 16:13:46 +0000 Subject: [PATCH 15/32] Remove unrelated change Signed-off-by: Tiotto, Ettore --- third_party/intel/lib/TritonIntelGPUTransforms/Coalesce.cpp | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/third_party/intel/lib/TritonIntelGPUTransforms/Coalesce.cpp b/third_party/intel/lib/TritonIntelGPUTransforms/Coalesce.cpp index 9a475a04bc..f10595854e 100644 --- a/third_party/intel/lib/TritonIntelGPUTransforms/Coalesce.cpp +++ b/third_party/intel/lib/TritonIntelGPUTransforms/Coalesce.cpp @@ -1,13 +1,10 @@ #include "intel/include/Analysis/AxisInfo.h" -// #include "intel/include/Dialect/TritonIntelGPU/IR/Dialect.h" #include "intel/include/Dialect/TritonIntelGPU/IR/Utils.h" #include "intel/include/Dialect/TritonIntelGPU/Transforms/Passes.h" #include "mlir/IR/Operation.h" #include "mlir/IR/Value.h" #include "mlir/IR/Verifier.h" #include "mlir/Support/LLVM.h" -// #include "triton/Dialect/Triton/IR/Dialect.h" -// #include "triton/Dialect/Triton/IR/Types.h" #include "triton/Dialect/Triton/IR/Utility.h" #include "triton/Dialect/TritonGPU/Transforms/Utility.h" #include "triton/Tools/StrUtil.h" @@ -379,8 +376,7 @@ struct CoalescePass }); LLVM_DEBUG({ - DBGS() << "\nlayoutMap:" - << "\n"; + DBGS() << "\nlayoutMap:" << "\n"; for (auto [op, encoding] : layoutMap) { DBGS() << "op: " << *op << "\n"; DBGS() << "encoding: " << encoding << "\n\n"; From 0aa334b1bc7e48960ef98e96287ebfad9490dea8 Mon Sep 17 00:00:00 2001 From: "Tiotto, Ettore" Date: Tue, 22 Oct 2024 16:14:50 +0000 
Subject: [PATCH 16/32] Remove unrelated change Signed-off-by: Tiotto, Ettore --- third_party/intel/lib/TritonIntelGPUTransforms/Coalesce.cpp | 3 --- 1 file changed, 3 deletions(-) diff --git a/third_party/intel/lib/TritonIntelGPUTransforms/Coalesce.cpp b/third_party/intel/lib/TritonIntelGPUTransforms/Coalesce.cpp index f10595854e..4d1e758f19 100644 --- a/third_party/intel/lib/TritonIntelGPUTransforms/Coalesce.cpp +++ b/third_party/intel/lib/TritonIntelGPUTransforms/Coalesce.cpp @@ -100,9 +100,6 @@ struct CoalescePass perThread = std::min(perThread, std::max(numElems / numThreads, 1)); LDBG("perThread: " << perThread); - // if (perThread <= 1) - // return; - if (!dyn_cast(op)) { // For ops that can result in a global memory write, we should enforce // that each thread handles at most 128 bits, which is the widest From 547d6faed49fd892a83cb7abeb2b93330191d6ce Mon Sep 17 00:00:00 2001 From: "Tiotto, Ettore" Date: Tue, 22 Oct 2024 19:19:57 +0000 Subject: [PATCH 17/32] Fix pre_commit Signed-off-by: Tiotto, Ettore --- third_party/intel/lib/TritonIntelGPUTransforms/Coalesce.cpp | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/third_party/intel/lib/TritonIntelGPUTransforms/Coalesce.cpp b/third_party/intel/lib/TritonIntelGPUTransforms/Coalesce.cpp index 4d1e758f19..15c4b43dcf 100644 --- a/third_party/intel/lib/TritonIntelGPUTransforms/Coalesce.cpp +++ b/third_party/intel/lib/TritonIntelGPUTransforms/Coalesce.cpp @@ -373,7 +373,8 @@ struct CoalescePass }); LLVM_DEBUG({ - DBGS() << "\nlayoutMap:" << "\n"; + DBGS() << "\nlayoutMap:" + << "\n"; for (auto [op, encoding] : layoutMap) { DBGS() << "op: " << *op << "\n"; DBGS() << "encoding: " << encoding << "\n\n"; From 2f97c1aeaf22da3001e4ac8330d4e40d5ca9290a Mon Sep 17 00:00:00 2001 From: "Tiotto, Ettore" Date: Wed, 23 Oct 2024 16:11:40 +0000 Subject: [PATCH 18/32] Address code review comments Signed-off-by: Tiotto, Ettore --- third_party/intel/lib/Analysis/AxisInfo.cpp | 3 --- .../lib/TritonIntelGPUTransforms/Coalesce.cpp | 17 +++++------------ 2 files changed, 5 insertions(+), 15 deletions(-) diff --git a/third_party/intel/lib/Analysis/AxisInfo.cpp b/third_party/intel/lib/Analysis/AxisInfo.cpp index 2b052ed419..378ba01442 100644 --- a/third_party/intel/lib/Analysis/AxisInfo.cpp +++ b/third_party/intel/lib/Analysis/AxisInfo.cpp @@ -1,7 +1,5 @@ -#include "intel/include/Analysis/AxisInfo.h" #include "mlir/Analysis/DataFlowFramework.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h" -#include "triton/Dialect/Triton/IR/Dialect.h" #include "llvm/Support/Debug.h" #include "llvm/Support/raw_ostream.h" @@ -1025,7 +1023,6 @@ class MakeTensorPtrOpAxisInfoVisitor final AxisInfo ptrInfo = operands[0]->getValue(); int64_t ptrDivisibility = ptrInfo.getDivisibility(0); - LDBG("ptrDivisibility: "); AxisInfo::DimVectorT contiguity, constancy, divisibility; for (int dim = 0; dim < rank; ++dim) { diff --git a/third_party/intel/lib/TritonIntelGPUTransforms/Coalesce.cpp b/third_party/intel/lib/TritonIntelGPUTransforms/Coalesce.cpp index d90a250ead..7753ca6707 100644 --- a/third_party/intel/lib/TritonIntelGPUTransforms/Coalesce.cpp +++ b/third_party/intel/lib/TritonIntelGPUTransforms/Coalesce.cpp @@ -1,6 +1,7 @@ #include "intel/include/Analysis/AxisInfo.h" #include "intel/include/Dialect/TritonIntelGPU/IR/Utils.h" #include "intel/include/Dialect/TritonIntelGPU/Transforms/Passes.h" +#include "intel/include/Dialect/TritonIntelGPU/Transforms/Utility.h" #include "mlir/IR/Operation.h" #include "mlir/IR/Value.h" #include "mlir/IR/Verifier.h" @@ -27,13 +28,6 @@ 
namespace ttgi = mlir::triton::gpu::intel;
 
 namespace {
 
-RankedTensorType getRankedTensorType(Type ptrTy) {
-  return tt::isTensorPointerType(ptrTy)
-             ? cast<RankedTensorType>(
-                   cast<tt::PointerType>(ptrTy).getPointeeType())
-             : dyn_cast<RankedTensorType>(ptrTy);
-}
-
 struct CoalescePass
     : public ttgi::impl::TritonIntelGPUCoalesceBase<CoalescePass> {
 private:
@@ -55,7 +49,7 @@ struct CoalescePass
     SmallVector<unsigned> order = argSort(contiguity);
     LDBG("order=[" << triton::join(order, ", ") << "]");
 
-    RankedTensorType refTensorType = getRankedTensorType(ptr.getType());
+    RankedTensorType refTensorType = ttgi::getRankedTensorType(ptr.getType());
     auto matchesShape = [&refTensorType](const Value &val) {
       auto rttType = dyn_cast<RankedTensorType>(val.getType());
       return rttType && rttType.getShape() == refTensorType.getShape();
@@ -279,7 +273,7 @@ struct CoalescePass
                  "Unexpected layout");
 
         auto resType = cast<tt::PointerType>(res.getType());
-        RankedTensorType tensorType = getRankedTensorType(resType);
+        RankedTensorType tensorType = ttgi::getRankedTensorType(resType);
         res.setType(tt::PointerType::get(getNewType(tensorType, layout),
                                          resType.getAddressSpace()));
       }
@@ -362,7 +356,7 @@ struct CoalescePass
       if (!ptr)
         return;
 
-      RankedTensorType refTensorType = getRankedTensorType(ptr.getType());
+      RankedTensorType refTensorType = ttgi::getRankedTensorType(ptr.getType());
 
       if (!refTensorType || !refTensorType.getEncoding())
         return;
@@ -373,8 +367,7 @@ struct CoalescePass
     });
 
     LLVM_DEBUG({
-      DBGS() << "\nlayoutMap:"
-             << "\n";
+      DBGS() << "\nlayoutMap:" << "\n";
       for (auto [op, encoding] : layoutMap) {
         DBGS() << "op: " << *op << "\n";
         DBGS() << "encoding: " << encoding << "\n\n";

From 95f5832e5d5376cb7365de2057977b42257d34b0 Mon Sep 17 00:00:00 2001
From: "Tiotto, Ettore"
Date: Wed, 23 Oct 2024 16:12:15 +0000
Subject: [PATCH 19/32] Fix pre_commit

Signed-off-by: Tiotto, Ettore
---
 third_party/intel/lib/TritonIntelGPUTransforms/Coalesce.cpp | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/third_party/intel/lib/TritonIntelGPUTransforms/Coalesce.cpp b/third_party/intel/lib/TritonIntelGPUTransforms/Coalesce.cpp
index 7753ca6707..7f52090f4e 100644
--- a/third_party/intel/lib/TritonIntelGPUTransforms/Coalesce.cpp
+++ b/third_party/intel/lib/TritonIntelGPUTransforms/Coalesce.cpp
@@ -367,7 +367,8 @@ struct CoalescePass
     });
 
     LLVM_DEBUG({
-      DBGS() << "\nlayoutMap:" << "\n";
+      DBGS() << "\nlayoutMap:"
+             << "\n";
       for (auto [op, encoding] : layoutMap) {
         DBGS() << "op: " << *op << "\n";
         DBGS() << "encoding: " << encoding << "\n\n";

From 3636befd79d85b1daa818679ea7bde2cc2324b8f Mon Sep 17 00:00:00 2001
From: "Tiotto, Ettore"
Date: Thu, 24 Oct 2024 20:51:26 +0000
Subject: [PATCH 20/32] Make isExpensiveLoadOrStore consider blocked pointers load and stores

Signed-off-by: Tiotto, Ettore
---
 .../RemoveLayoutConversions.cpp | 56 ++++++++++++++++++-
 .../lib/TritonIntelGPUTransforms/Utility.cpp | 10 +++-
 2 files changed, 62 insertions(+), 4 deletions(-)

diff --git a/third_party/intel/lib/TritonIntelGPUTransforms/RemoveLayoutConversions.cpp b/third_party/intel/lib/TritonIntelGPUTransforms/RemoveLayoutConversions.cpp
index e91cfa34c0..1546dc5031 100644
--- a/third_party/intel/lib/TritonIntelGPUTransforms/RemoveLayoutConversions.cpp
+++ b/third_party/intel/lib/TritonIntelGPUTransforms/RemoveLayoutConversions.cpp
@@ -305,8 +305,29 @@ bool hasConvertToMMATransisitiveUse(Operation *op, Attribute encoding) {
 
 // Return true if the op is an op with a layout we don't want to change. We will
 // propagate the layout starting from anchor ops.
bool isLayoutAnchor(Operation *op) { - if (isa(op)) + if (isa(op)) { +#ifdef HACK + // Note: currently block ptr loads are always considered not expensive and + // therefore they are never layout anchors. + Value base = op->getOperand(0); + auto parentLoop = op->getParentOfType(); + bool isInLoop = parentLoop != nullptr; + bool isTensorPtrLoad = mlir::triton::isTensorPointerType(base.getType()); + + if (!isTensorPtrLoad) + ttgi::isExpensiveLoadOrStore(op); + + // HACK: consider block ptr loads expensive if they are in a loop. + return isInLoop; +#else return ttgi::isExpensiveLoadOrStore(op); +#endif + } + + if (isa(op)) { + return ttgi::isExpensiveLoadOrStore(op); + } + if (isa(op)) return true; if (isa(op)) @@ -356,6 +377,17 @@ void LayoutPropagation::initAnchorLayout() { } } }); + +#if 0 + llvm::errs() << "Initial layouts:\n"; + for (auto &entry : layouts) { + llvm::errs() << entry.first << "\n"; + for (auto &layout : entry.second.encodings) { + llvm::errs() << " " << layout << "\n"; + } + } + llvm::errs() << "\n\n"; +#endif } void LayoutPropagation::setEncoding(ValueRange values, LayoutInfo &info, @@ -969,8 +1001,28 @@ Operation *LayoutPropagation::rewriteOp(Operation *op) { } bool canBeRemat(Operation *op) { - if (isa(op)) + if (isa(op)) { +#ifdef HACK + // Note: currently block ptr loads are always considered not expensive and + // therefore rematerializable. + Value base = op->getOperand(0); + auto parentLoop = op->getParentOfType(); + bool isInLoop = parentLoop != nullptr; + bool isTensorPtrLoad = mlir::triton::isTensorPointerType(base.getType()); + + if (!isTensorPtrLoad) + return !ttgi::isExpensiveLoadOrStore(op); + + // HACK: consider block ptr loads expensive if they are in a loop. + return !isInLoop; +#else + return !ttgi::isExpensiveLoadOrStore(op); +#endif + } + + if (isa(op)) return !ttgi::isExpensiveLoadOrStore(op); + if (isa(op)) return false; if (isa(op)) diff --git a/third_party/intel/lib/TritonIntelGPUTransforms/Utility.cpp b/third_party/intel/lib/TritonIntelGPUTransforms/Utility.cpp index 759fc1782d..3f07283bfc 100644 --- a/third_party/intel/lib/TritonIntelGPUTransforms/Utility.cpp +++ b/third_party/intel/lib/TritonIntelGPUTransforms/Utility.cpp @@ -90,9 +90,15 @@ bool isExpensiveLoadOrStore(Operation *op) { if (isSingleValue(base)) return false; - // Case 2: Tensor of pointers has more threads than elements - // we can presume a high hit-rate that makes it cheap to load + // Case 2: Tensor of pointers has more threads than elements + // we can presume a high hit-rate that makes it cheap to load + +#define NEW 1 +#ifdef NEW + if (auto ptrType = getRankedTensorType(base.getType())) { +#else if (auto ptrType = dyn_cast(base.getType())) { +#endif auto mod = op->getParentOfType(); int numWarps = ttg::TritonGPUDialect::getNumWarps(mod); int threadsPerWarp = ttg::TritonGPUDialect::getThreadsPerWarp(mod); From db2193e4aa0eaa81c420723dc5e5a33c8665ab3a Mon Sep 17 00:00:00 2001 From: "Tiotto, Ettore" Date: Fri, 25 Oct 2024 14:45:12 +0000 Subject: [PATCH 21/32] Make isExpensiveLoadOrStore consider blocked pointers load and stores Signed-off-by: Tiotto, Ettore --- .../lib/TritonIntelGPUTransforms/Utility.cpp | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/third_party/intel/lib/TritonIntelGPUTransforms/Utility.cpp b/third_party/intel/lib/TritonIntelGPUTransforms/Utility.cpp index 3f07283bfc..e5d8d09a03 100644 --- a/third_party/intel/lib/TritonIntelGPUTransforms/Utility.cpp +++ b/third_party/intel/lib/TritonIntelGPUTransforms/Utility.cpp @@ -11,6 +11,7 @@ #include 
"mlir/Transforms/DialectConversion.h" #include "intel/include/Dialect/TritonIntelGPU/IR/Attributes.h" +#include "intel/include/Dialect/TritonIntelGPU/IR/Dialect.h" #include "intel/include/Dialect/TritonIntelGPU/Transforms/Utility.h" #include "triton/Conversion/TritonToTritonGPU/TritonToTritonGPUPass.h" #include "triton/Dialect/TritonGPU/IR/Dialect.h" @@ -93,8 +94,22 @@ bool isExpensiveLoadOrStore(Operation *op) { // Case 2: Tensor of pointers has more threads than elements // we can presume a high hit-rate that makes it cheap to load + // IDEA: Block pointers loads are expensive if: + // - they cannot be lowered to 2D block reads (they feed a dot operation) + // - temporarily we can look at the "triton_intel_gpu.block_io" attribute, + // if it has it it can be lowered to 2D block reads + // + // + #define NEW 1 #ifdef NEW + Attribute blockIOAttr = + op->getAttr(TritonIntelGPUDialect::getBlockIOAttrName()); + if (blockIOAttr) { + llvm::errs() << "load op: " << *op << " is not expensive\n"; + return false; + } + if (auto ptrType = getRankedTensorType(base.getType())) { #else if (auto ptrType = dyn_cast(base.getType())) { From 7c9a0f93f2d2047e9334a5895cbcb8adc8cb46e2 Mon Sep 17 00:00:00 2001 From: "Tiotto, Ettore" Date: Fri, 25 Oct 2024 19:53:49 +0000 Subject: [PATCH 22/32] MaterializeBlockPointer fix for GEMM with 1st operand transposed Signed-off-by: Tiotto, Ettore --- .../MaterializeBlockPointer.cpp | 54 ++++++++++++++++--- 1 file changed, 47 insertions(+), 7 deletions(-) diff --git a/third_party/intel/lib/TritonIntelGPUTransforms/MaterializeBlockPointer.cpp b/third_party/intel/lib/TritonIntelGPUTransforms/MaterializeBlockPointer.cpp index 601e3694e9..999433332e 100644 --- a/third_party/intel/lib/TritonIntelGPUTransforms/MaterializeBlockPointer.cpp +++ b/third_party/intel/lib/TritonIntelGPUTransforms/MaterializeBlockPointer.cpp @@ -12,6 +12,7 @@ using namespace mlir; namespace tt = mlir::triton; +namespace ttg = mlir::triton::gpu; namespace ttgi = mlir::triton::gpu::intel; namespace mlir::triton::gpu::intel { @@ -37,7 +38,7 @@ struct TritonIntelGPUMaterializeBlockPointerPass return; MLIRContext *context = &getContext(); - mod.walk([context](tt::LoadOp loadOp) { + mod.walk([context, this](tt::LoadOp loadOp) { LDBG("Considering op: " << loadOp); Value ptr = loadOp.getPtr(); @@ -51,7 +52,6 @@ struct TritonIntelGPUMaterializeBlockPointerPass LDBG("Found make tensor ptr op: " << makeTensorPtrOp); auto ptrType = cast(makeTensorPtrOp.getType()); auto tensorType = cast(ptrType.getPointeeType()); - auto dotLayout = ttgi::getDotEncoding(tensorType); Operation::operand_range shape = makeTensorPtrOp.getShape(); unsigned rank = shape.size(); @@ -100,11 +100,11 @@ struct TritonIntelGPUMaterializeBlockPointerPass return; const bool isRowMajor = fastChangeDim == rank - 1; - if (dotLayout) { - // Check if the load is being used in a dot layout, and if so is this - // the first op and is it a transposed row major matrix. If so, skip - // the block ptr attribute as performance is worse than if we remove - // the tensor pointer + if (auto dotLayout = getDotLayout(loadOp)) { + // Check if the load is being used by a tt.dot operation, and if so is + // this the first operand and is it a transposed row major matrix. If + // so, skip the block ptr attribute as performance is worse than if we + // remove the tensor pointer. 
LDBG("dotLayout: " << *dotLayout); const unsigned opIdx = dotLayout->getOpIdx(); auto dotOrder = dotLayout->getThreadOrder(); @@ -122,6 +122,46 @@ struct TritonIntelGPUMaterializeBlockPointerPass } }); } + +private: + // Return the load layout if it is a dot layout. If it is not, check if the + // load result is converted to a dot layout. If so, return the dot layout, + // otherwise return nullopt. + std::optional + getDotLayout(tt::LoadOp loadOp) const { + Value ptr = loadOp.getPtr(); + if (!tt::isTensorPointerType(ptr.getType())) + return nullptr; + + RankedTensorType tensorType = ttgi::getRankedTensorType(ptr.getType()); + auto dotLayout = ttgi::getDotEncoding(tensorType); + if (dotLayout) + return dotLayout; + + auto allUsersAreConvertOps = [](Operation::user_range users) { + return llvm::all_of(users, [](Operation *user) { + return isa(user); + }); + }; + + auto allUserHaveIdenticalLayout = [](Operation::user_range users) { + Attribute firstUserLayout = + cast(*users.begin()).getType().getEncoding(); + return llvm::all_of(users, [&firstUserLayout](Operation *user) { + return firstUserLayout == + cast(user).getType().getEncoding(); + }); + }; + + Operation::user_range users = loadOp->getUsers(); + if (allUsersAreConvertOps(users) && allUserHaveIdenticalLayout(users)) { + Attribute firstUserLayout = + cast(*users.begin()).getType().getEncoding(); + return dyn_cast(firstUserLayout); + } + + return nullptr; + } }; } // anonymous namespace From cbc630b45c31c24c46ebc5bee5a10e23f5e7c795 Mon Sep 17 00:00:00 2001 From: "Tiotto, Ettore" Date: Fri, 25 Oct 2024 21:28:18 +0000 Subject: [PATCH 23/32] MaterializeBlockPointer fix for GEMM with 1st operand transposed Signed-off-by: Tiotto, Ettore --- .../MaterializeBlockPointer.cpp | 15 +++++++++++---- 1 file changed, 11 insertions(+), 4 deletions(-) diff --git a/third_party/intel/lib/TritonIntelGPUTransforms/MaterializeBlockPointer.cpp b/third_party/intel/lib/TritonIntelGPUTransforms/MaterializeBlockPointer.cpp index 999433332e..23788039f8 100644 --- a/third_party/intel/lib/TritonIntelGPUTransforms/MaterializeBlockPointer.cpp +++ b/third_party/intel/lib/TritonIntelGPUTransforms/MaterializeBlockPointer.cpp @@ -4,7 +4,9 @@ #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/IR/Visitors.h" #include "triton/Analysis/Utility.h" +#include "llvm/Support/Casting.h" #include "llvm/Support/Debug.h" +#include #define DEBUG_TYPE "tritonintelgpu-materialize-block-pointer" #define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ") @@ -131,9 +133,12 @@ struct TritonIntelGPUMaterializeBlockPointerPass getDotLayout(tt::LoadOp loadOp) const { Value ptr = loadOp.getPtr(); if (!tt::isTensorPointerType(ptr.getType())) - return nullptr; + return std::nullopt; RankedTensorType tensorType = ttgi::getRankedTensorType(ptr.getType()); + if (!tensorType) + return std::nullopt; + auto dotLayout = ttgi::getDotEncoding(tensorType); if (dotLayout) return dotLayout; @@ -154,13 +159,15 @@ struct TritonIntelGPUMaterializeBlockPointerPass }; Operation::user_range users = loadOp->getUsers(); - if (allUsersAreConvertOps(users) && allUserHaveIdenticalLayout(users)) { + if (!users.empty() && allUsersAreConvertOps(users) && + allUserHaveIdenticalLayout(users)) { Attribute firstUserLayout = cast(*users.begin()).getType().getEncoding(); - return dyn_cast(firstUserLayout); + return llvm::dyn_cast_if_present( + firstUserLayout); } - return nullptr; + return std::nullopt; } }; From 0215a163df035b45b7979a044b6fce8a88dea82d Mon Sep 17 00:00:00 2001 From: "Tiotto, Ettore" Date: Mon, 28 Oct 2024 
17:55:52 +0000 Subject: [PATCH 24/32] Fix unit tests Signed-off-by: Tiotto, Ettore --- .../backward_combine_dpas_dot_layout.mlir | 21 +++++++++---------- test/TritonIntelGPU/combine.mlir | 14 ++++++------- 2 files changed, 16 insertions(+), 19 deletions(-) diff --git a/test/TritonIntelGPU/backward_combine_dpas_dot_layout.mlir b/test/TritonIntelGPU/backward_combine_dpas_dot_layout.mlir index d6f43af96d..cf4e3ad8de 100644 --- a/test/TritonIntelGPU/backward_combine_dpas_dot_layout.mlir +++ b/test/TritonIntelGPU/backward_combine_dpas_dot_layout.mlir @@ -47,16 +47,16 @@ module attributes {"triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-war // CHECK: %[[VAL_40:.*]] = tt.make_tensor_ptr %{{.*}}, {{\[}}%{{.*}}, %{{.*}}], {{\[}}%{{.*}}, %{{.*}}], {{\[}}%{{.*}}, %{{.*}}] {order = array} : >> %22 = tt.make_tensor_ptr %arg1, [%16, %20], [%21, %c1_i64], [%c0_i32, %19] {order = array} : > // CHECK: %[[VAL_41:.*]]:3 = scf.for %{{.*}} = %{{.*}} to %{{.*}} step %{{.*}} iter_args(%{{.*}} = %{{.*}}, %{{.*}} = %[[VAL_36]], %{{.*}} = %[[VAL_40]]) -> (tensor<64x256xf32, #[[DPAS]]>, !tt.ptr>>, !tt.ptr>>) : i32 { - // CHECK: %[[VAL_46:.*]] = tt.load %{{.*}} {boundaryCheck = array} : !tt.ptr>> - // CHECK: %[[VAL_47:.*]] = tt.load %{{.*}} {boundaryCheck = array} : !tt.ptr>> + // CHECK: %[[VAL_46:.*]] = tt.load %{{.*}} {boundaryCheck = array, triton_intel_gpu.block_io = "row_major"} : !tt.ptr>> + // CHECK: %[[VAL_47:.*]] = tt.load %{{.*}} {boundaryCheck = array, triton_intel_gpu.block_io = "row_major"} : !tt.ptr>> // CHECK-NOT: triton_gpu.convert_layout // CHECK-NEXT: %[[VAL_48:.*]] = tt.dot %[[VAL_46]], %[[VAL_47]], %{{.*}}, inputPrecision = tf32 : tensor<64x32xf16, #{{.*}}<{opIdx = 0, parent = #[[DPAS]], kWidth = 2}>> * tensor<32x256xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #[[DPAS]], kWidth = 2}>> -> tensor<64x256xf32, #[[DPAS]]> // CHECK: %[[VAL_49:.*]] = tt.advance %{{.*}}, {{\[}}%{{.*}}, %{{.*}}] : >> // CHECK: %[[VAL_50:.*]] = tt.advance %{{.*}}, {{\[}}%{{.*}}, %{{.*}}] : >> // CHECK: scf.yield %{{.*}}, %{{.*}}, %{{.*}} : tensor<64x256xf32, #[[DPAS]]>, !tt.ptr>>, !tt.ptr>> %23:3 = scf.for %arg9 = %c0_i32 to %arg5 step %c32_i32 iter_args(%arg10 = %cst, %arg11 = %18, %arg12 = %22) -> (tensor<64x256xf32, #dpas>, !tt.ptr>, !tt.ptr>) : i32 { - %28 = tt.load %arg11 {boundaryCheck = array} : !tt.ptr> - %29 = tt.load %arg12 {boundaryCheck = array} : !tt.ptr> + %28 = tt.load %arg11 {boundaryCheck = array, triton_intel_gpu.block_io = "row_major"} : !tt.ptr> + %29 = tt.load %arg12 {boundaryCheck = array, triton_intel_gpu.block_io = "row_major"} : !tt.ptr> %30 = triton_gpu.convert_layout %28 : tensor<64x32xf16, #blocked> -> tensor<64x32xf16, #dot0> %31 = triton_gpu.convert_layout %29 : tensor<32x256xf16, #blocked1> -> tensor<32x256xf16, #dot1> %32 = tt.dot %30, %31, %arg10, inputPrecision = tf32 : tensor<64x32xf16, #dot0> * tensor<32x256xf16, #dot1> -> tensor<64x256xf32, #dpas> @@ -130,7 +130,6 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : scf.yield %32, %33, %34 : tensor<64x256xf32, #dpas>, !tt.ptr>, !tt.ptr> } %24 = arith.truncf %23#0 : tensor<64x256xf32, #dpas> to tensor<64x256xf16, #dpas> - // CHECK-NOT: triton_gpu.convert_layout %25 = triton_gpu.convert_layout %24 : tensor<64x256xf16, #dpas> -> tensor<64x256xf16, #blocked1> %26 = arith.extsi %arg8 : i32 to i64 // CHECK: tt.make_tensor_ptr {{.*}}, {{\[}}{{.*}}, {{.*}}], {{\[}}{{.*}}, {{.*}}], {{\[}}{{.*}}, {{.*}}] {order = array} : > @@ -147,6 +146,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, 
"triton_gpu.num-warps" = 4 : // COM: Checks that DPAS encoding has been forwarded to the store op // COM: The `tt.make_tensor_ptr` has multiple users (the storeOp + another OP) // COM: The initial `tt.make_tensor_ptr` with non-DPAS encoding must be kept. +// CHECK: #[[BLOCKED:.+]] = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 16], warpsPerCTA = [1, 4], order = [1, 0]}> // CHECK: #[[DPAS:.+]] = #triton_intel_gpu.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 2, threadsPerWarp = 16, warpsPerCTA = [1, 4], repCluster = [1, 1], A = [8, 16], B = [16, 16], C = [8, 16]}> #blocked = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 16], warpsPerCTA = [2, 2], order = [1, 0]}> #blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 16], warpsPerCTA = [1, 4], order = [1, 0]}> @@ -188,8 +188,8 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : %21 = arith.extsi %arg7 : i32 to i64 %22 = tt.make_tensor_ptr %arg1, [%16, %20], [%21, %c1_i64], [%c0_i32, %19] {order = array} : > %23:3 = scf.for %arg9 = %c0_i32 to %arg5 step %c32_i32 iter_args(%arg10 = %cst, %arg11 = %18, %arg12 = %22) -> (tensor<64x256xf32, #dpas>, !tt.ptr>, !tt.ptr>) : i32 { - %28 = tt.load %arg11 {boundaryCheck = array} : !tt.ptr> - %29 = tt.load %arg12 {boundaryCheck = array} : !tt.ptr> + %28 = tt.load %arg11 {boundaryCheck = array, triton_intel_gpu.block_io = "row_major"} : !tt.ptr> + %29 = tt.load %arg12 {boundaryCheck = array, triton_intel_gpu.block_io = "row_major"} : !tt.ptr> %30 = triton_gpu.convert_layout %28 : tensor<64x32xf16, #blocked> -> tensor<64x32xf16, #dot0> %31 = triton_gpu.convert_layout %29 : tensor<32x256xf16, #blocked1> -> tensor<32x256xf16, #dot1> %32 = tt.dot %30, %31, %arg10, inputPrecision = tf32 : tensor<64x32xf16, #dot0> * tensor<32x256xf16, #dot1> -> tensor<64x256xf32, #dpas> @@ -198,11 +198,10 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : scf.yield %32, %33, %34 : tensor<64x256xf32, #dpas>, !tt.ptr>, !tt.ptr> } %24 = arith.truncf %23#0 : tensor<64x256xf32, #dpas> to tensor<64x256xf16, #dpas> - // CHECK-NOT: triton_gpu.convert_layout %25 = triton_gpu.convert_layout %24 : tensor<64x256xf16, #dpas> -> tensor<64x256xf16, #blocked1> %26 = arith.extsi %arg8 : i32 to i64 // CHECK: tt.make_tensor_ptr {{.*}}, {{\[}}{{.*}}, {{.*}}], {{\[}}{{.*}}, {{.*}}], {{\[}}{{.*}}, {{.*}}] {order = array} : > - // CHECK: tt.make_tensor_ptr {{.*}}, {{\[}}{{.*}}, {{.*}}], {{\[}}{{.*}}, {{.*}}], {{\[}}{{.*}}, {{.*}}] {order = array} : >> + // CHECK: tt.make_tensor_ptr {{.*}}, {{\[}}{{.*}}, {{.*}}], {{\[}}{{.*}}, {{.*}}], {{\[}}{{.*}}, {{.*}}] {order = array} : > %27 = tt.make_tensor_ptr %arg2, [%15, %20], [%26, %c1_i64], [%14, %19] {order = array} : > // CHECK: tt.store {{.*}}, {{.*}} {boundaryCheck = array} : !tt.ptr> tt.store %27, %25 {boundaryCheck = array} : !tt.ptr> @@ -243,8 +242,8 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : %18 = tt.make_tensor_ptr %arg0, [%c0_i64, %c0_i64], [%c0_i64, %c1_i64], [%c0_i32, %c0_i32] {order = array} : > %22 = tt.make_tensor_ptr %arg1, [%c0_i64, %c0_i64], [%c0_i64, %c1_i64], [%c0_i32, %c0_i32] {order = array} : > %23:3 = scf.for %arg9 = %c0_i32 to %arg5 step %c32_i32 iter_args(%arg10 = %cst, %arg11 = %18, %arg12 = %22) -> (tensor<64x256xf32, #blocked1>, !tt.ptr>, !tt.ptr>) : i32 { - %28 = tt.load %arg11 {boundaryCheck = array} : !tt.ptr> - %29 = tt.load %arg12 {boundaryCheck = array} : !tt.ptr> + %28 = 
tt.load %arg11 {boundaryCheck = array, triton_intel_gpu.block_io = "row_major" } : !tt.ptr> + %29 = tt.load %arg12 {boundaryCheck = array, triton_intel_gpu.block_io = "row_major"} : !tt.ptr> %36 = triton_gpu.convert_layout %arg10 : tensor<64x256xf32, #blocked1> -> tensor<64x256xf32, #dpas> %30 = triton_gpu.convert_layout %28 : tensor<64x32xf16, #blocked> -> tensor<64x32xf16, #dot0> %31 = triton_gpu.convert_layout %29 : tensor<32x256xf16, #blocked1> -> tensor<32x256xf16, #dot1> diff --git a/test/TritonIntelGPU/combine.mlir b/test/TritonIntelGPU/combine.mlir index 64f3193653..318d957c63 100644 --- a/test/TritonIntelGPU/combine.mlir +++ b/test/TritonIntelGPU/combine.mlir @@ -2324,23 +2324,22 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 32 %cst_1 = arith.constant dense<0.000000e+00> : tensor<256x256xf32, #blocked2> %0 = tt.get_program_id x : i32 %1 = tt.get_program_id y : i32 - // CHECK: %[[VAL_0:.*]] = tt.make_tensor_ptr {{.*}} : >> - // CHECK: %[[VAL_1:.*]] = tt.make_tensor_ptr {{.*}} : >> + // CHECK: %[[VAL_0:.*]] = tt.make_tensor_ptr {{.*}} : > + // CHECK: %[[VAL_1:.*]] = tt.make_tensor_ptr {{.*}} : > %12 = tt.make_tensor_ptr %arg0, [%c4096_i64, %c4096_i64], [%c4096_i64, %c1_i64], [%0, %1] {order = array} : > %14 = tt.make_tensor_ptr %arg1, [%c4096_i64, %c4096_i64], [%c4096_i64, %c1_i64], [%0, %1] {order = array} : > - // CHECK: %[[VAL_2:.*]]:3 = scf.for {{.*}} -> (tensor<256x256xf32, #[[$DPAS]]>, !tt.ptr>>, !tt.ptr>>) : i32 { + // CHECK: %[[VAL_2:.*]]:3 = scf.for {{.*}} -> (tensor<256x256xf32, #[[$DPAS]]>, !tt.ptr>, !tt.ptr>) : i32 { %15:3 = scf.for %arg3 = %c0_i32 to %c4096_i32 step %c128_i32 iter_args(%arg4 = %cst_1, %arg5 = %12, %arg6 = %14) -> (tensor<256x256xf32, #blocked2>, !tt.ptr>, !tt.ptr>) : i32 { %47 = tt.load %arg5 : !tt.ptr> %48 = tt.load %arg6 : !tt.ptr> - // CHEKC-NOT: triton_gpu.convert_layout %49 = triton_gpu.convert_layout %arg4 : tensor<256x256xf32, #blocked2> -> tensor<256x256xf32, #mma> %50 = triton_gpu.convert_layout %47 : tensor<256x32xbf16, #blocked3> -> tensor<256x32xbf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> %51 = triton_gpu.convert_layout %48 : tensor<32x256xbf16, #blocked2> -> tensor<32x256xbf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> %52 = tt.dot %50, %51, %49, inputPrecision = tf32 : tensor<256x32xbf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * tensor<32x256xbf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<256x256xf32, #mma> %53 = triton_gpu.convert_layout %52 : tensor<256x256xf32, #mma> -> tensor<256x256xf32, #blocked2> - // CHECK: %[[VAL_3:.*]] = tt.advance {{.*}} : >> - // CHECK: %[[VAL_4:.*]] = tt.advance {{.*}} : >> - // CHECK: scf.yield {{.*}} : tensor<256x256xf32, #[[$DPAS]]>, !tt.ptr>>, !tt.ptr>> + // CHECK: %[[VAL_3:.*]] = tt.advance {{.*}} : > + // CHECK: %[[VAL_4:.*]] = tt.advance {{.*}} : > + // CHECK: scf.yield {{.*}} : tensor<256x256xf32, #[[$DPAS]]>, !tt.ptr>, !tt.ptr> %54 = tt.advance %arg5, [%c0_i32, %c128_i32] : > %55 = tt.advance %arg6, [%c128_i32, %c0_i32] : > scf.yield %53, %54, %55 : tensor<256x256xf32, #blocked2>, !tt.ptr>, !tt.ptr> @@ -2348,7 +2347,6 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 32 %16 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #blocked> %32 = tt.splat %arg2 : !tt.ptr -> tensor<256x256x!tt.ptr, #blocked2> %38 = arith.cmpi slt, %16, %cst : tensor<256xi32, #blocked> - // CHEKC-NOT: triton_gpu.convert_layout %39 = triton_gpu.convert_layout 
%38 : tensor<256xi1, #blocked> -> tensor<256xi1, #triton_gpu.slice<{dim = 0, parent = #blocked4}>> %40 = tt.expand_dims %39 {axis = 0 : i32} : tensor<256xi1, #triton_gpu.slice<{dim = 0, parent = #blocked4}>> -> tensor<1x256xi1, #blocked4> %41 = triton_gpu.convert_layout %40 : tensor<1x256xi1, #blocked4> -> tensor<1x256xi1, #blocked2> From ae3d625265f6188691fe2010b7b1b033aec60d21 Mon Sep 17 00:00:00 2001 From: "Tiotto, Ettore" Date: Mon, 28 Oct 2024 17:56:34 +0000 Subject: [PATCH 25/32] Fix performance regression for gemm-preop-exp Signed-off-by: Tiotto, Ettore --- .../MaterializeBlockPointer.cpp | 9 +++-- .../Pipeliner/MatmulLoopPipeline.cpp | 35 +++++++++---------- 2 files changed, 22 insertions(+), 22 deletions(-) diff --git a/third_party/intel/lib/TritonIntelGPUTransforms/MaterializeBlockPointer.cpp b/third_party/intel/lib/TritonIntelGPUTransforms/MaterializeBlockPointer.cpp index 23788039f8..f8f554bb02 100644 --- a/third_party/intel/lib/TritonIntelGPUTransforms/MaterializeBlockPointer.cpp +++ b/third_party/intel/lib/TritonIntelGPUTransforms/MaterializeBlockPointer.cpp @@ -102,7 +102,9 @@ struct TritonIntelGPUMaterializeBlockPointerPass return; const bool isRowMajor = fastChangeDim == rank - 1; - if (auto dotLayout = getDotLayout(loadOp)) { + std::optional dotLayout = + getDotLayout(loadOp); + if (dotLayout) { // Check if the load is being used by a tt.dot operation, and if so is // this the first operand and is it a transposed row major matrix. If // so, skip the block ptr attribute as performance is worse than if we @@ -163,8 +165,9 @@ struct TritonIntelGPUMaterializeBlockPointerPass allUserHaveIdenticalLayout(users)) { Attribute firstUserLayout = cast(*users.begin()).getType().getEncoding(); - return llvm::dyn_cast_if_present( - firstUserLayout); + if (isa(firstUserLayout)) + return dyn_cast(firstUserLayout); + return std::nullopt; } return std::nullopt; diff --git a/third_party/intel/lib/TritonIntelGPUTransforms/Pipeliner/MatmulLoopPipeline.cpp b/third_party/intel/lib/TritonIntelGPUTransforms/Pipeliner/MatmulLoopPipeline.cpp index 7f5789db3a..26190b98c3 100644 --- a/third_party/intel/lib/TritonIntelGPUTransforms/Pipeliner/MatmulLoopPipeline.cpp +++ b/third_party/intel/lib/TritonIntelGPUTransforms/Pipeliner/MatmulLoopPipeline.cpp @@ -10,6 +10,8 @@ #include "llvm/Support/Debug.h" #define DEBUG_TYPE "tritonintelgpu-pipeline" +#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ") +#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n") using namespace mlir; namespace tt = mlir::triton; @@ -55,30 +57,25 @@ static ttg::DotOperandEncodingAttr getDotEncodingFromUser(Operation *user) { if (!tensorType) return nullptr; - if (isa(tensorType.getEncoding())) - return allTransitiveUsesHaveDotEncoding(res); - - return llvm::dyn_cast_or_null( - tensorType.getEncoding()); + Attribute layout = tensorType.getEncoding(); + return isa(layout) + ? allTransitiveUsesHaveDotEncoding(res) + : llvm::dyn_cast_or_null(layout); } /// If all the transitive uses of the given value are used by a convert to the /// same dot operand encoding, return the encoding. Otherwise return nullptr. 
static ttg::DotOperandEncodingAttr allTransitiveUsesHaveDotEncoding(Value val) { ttg::DotOperandEncodingAttr attr{nullptr}; - LLVM_DEBUG(llvm::dbgs() << "Checking users of " << val << "\n"); + LDBG("Checking users of " << val); for (Operation *user : val.getUsers()) { - ttg::DotOperandEncodingAttr dotAttr; - if (isa(user)) { - auto tensorType = cast(val.getType()); - dotAttr = dyn_cast(tensorType.getEncoding()); - } else { - dotAttr = getDotEncodingFromUser(user); - } + ttg::DotOperandEncodingAttr dotAttr = + isa(user) + ? dyn_cast( + cast(val.getType()).getEncoding()) + : getDotEncodingFromUser(user); if (!dotAttr || (attr != nullptr && attr != dotAttr)) { - LLVM_DEBUG({ - llvm::dbgs() << "no dot attribute found for user: " << user << "\n"; - }); + LDBG("no dot attribute found for user: " << *user); return nullptr; } attr = dotAttr; @@ -292,14 +289,14 @@ bool ttgi::preProcessLoopAndGetSchedule(scf::ForOp &forOp, int numStages, SmallVector loads; collectOpsToPipeline(forOp, loads, supportRegularPtr); if (loads.empty()) { - LLVM_DEBUG(llvm::dbgs() << "No loads to pipeline\n"); + LDBG("No loads to pipeline"); return false; } LLVM_DEBUG({ - llvm::dbgs() << "Loads to pipeline:\n"; + DBGS() << "Loads to pipeline:\n"; for (const LoadDotOperand &load : loads) - llvm::dbgs() << " " << *load.load << "\n"; + DBGS() << " " << *load.load << "\n"; }); // 2. Create the prefetching operations for the loads collected. From 22b7ec9bc5bf104ba167b5e068e3e4e0f613cf85 Mon Sep 17 00:00:00 2001 From: "Tiotto, Ettore" Date: Mon, 28 Oct 2024 20:16:31 +0000 Subject: [PATCH 26/32] Reduce PR footprint Signed-off-by: Tiotto, Ettore --- .../RemoveLayoutConversions.cpp | 56 +------------------ .../lib/TritonIntelGPUTransforms/Utility.cpp | 28 +++------- 2 files changed, 10 insertions(+), 74 deletions(-) diff --git a/third_party/intel/lib/TritonIntelGPUTransforms/RemoveLayoutConversions.cpp b/third_party/intel/lib/TritonIntelGPUTransforms/RemoveLayoutConversions.cpp index 1546dc5031..e91cfa34c0 100644 --- a/third_party/intel/lib/TritonIntelGPUTransforms/RemoveLayoutConversions.cpp +++ b/third_party/intel/lib/TritonIntelGPUTransforms/RemoveLayoutConversions.cpp @@ -305,29 +305,8 @@ bool hasConvertToMMATransisitiveUse(Operation *op, Attribute encoding) { // Return true if the op is an op with a layout we don't want to change. We will // propagate the layout starting from anchor ops. bool isLayoutAnchor(Operation *op) { - if (isa(op)) { -#ifdef HACK - // Note: currently block ptr loads are always considered not expensive and - // therefore they are never layout anchors. - Value base = op->getOperand(0); - auto parentLoop = op->getParentOfType(); - bool isInLoop = parentLoop != nullptr; - bool isTensorPtrLoad = mlir::triton::isTensorPointerType(base.getType()); - - if (!isTensorPtrLoad) - ttgi::isExpensiveLoadOrStore(op); - - // HACK: consider block ptr loads expensive if they are in a loop. 
- return isInLoop; -#else + if (isa(op)) return ttgi::isExpensiveLoadOrStore(op); -#endif - } - - if (isa(op)) { - return ttgi::isExpensiveLoadOrStore(op); - } - if (isa(op)) return true; if (isa(op)) @@ -377,17 +356,6 @@ void LayoutPropagation::initAnchorLayout() { } } }); - -#if 0 - llvm::errs() << "Initial layouts:\n"; - for (auto &entry : layouts) { - llvm::errs() << entry.first << "\n"; - for (auto &layout : entry.second.encodings) { - llvm::errs() << " " << layout << "\n"; - } - } - llvm::errs() << "\n\n"; -#endif } void LayoutPropagation::setEncoding(ValueRange values, LayoutInfo &info, @@ -1001,28 +969,8 @@ Operation *LayoutPropagation::rewriteOp(Operation *op) { } bool canBeRemat(Operation *op) { - if (isa(op)) { -#ifdef HACK - // Note: currently block ptr loads are always considered not expensive and - // therefore rematerializable. - Value base = op->getOperand(0); - auto parentLoop = op->getParentOfType(); - bool isInLoop = parentLoop != nullptr; - bool isTensorPtrLoad = mlir::triton::isTensorPointerType(base.getType()); - - if (!isTensorPtrLoad) - return !ttgi::isExpensiveLoadOrStore(op); - - // HACK: consider block ptr loads expensive if they are in a loop. - return !isInLoop; -#else - return !ttgi::isExpensiveLoadOrStore(op); -#endif - } - - if (isa(op)) + if (isa(op)) return !ttgi::isExpensiveLoadOrStore(op); - if (isa(op)) return false; if (isa(op)) diff --git a/third_party/intel/lib/TritonIntelGPUTransforms/Utility.cpp b/third_party/intel/lib/TritonIntelGPUTransforms/Utility.cpp index e5d8d09a03..7fe296695a 100644 --- a/third_party/intel/lib/TritonIntelGPUTransforms/Utility.cpp +++ b/third_party/intel/lib/TritonIntelGPUTransforms/Utility.cpp @@ -86,34 +86,22 @@ bool isExpensiveLoadOrStore(Operation *op) { "Expecting Triton LoadOp or StoreOp"); Value base = op->getOperand(0); - // Case 1: A size 1 tensor is not expensive since all threads will load the - // same + // A size 1 tensor is not expensive since all threads will load the same + // value. if (isSingleValue(base)) return false; - // Case 2: Tensor of pointers has more threads than elements - // we can presume a high hit-rate that makes it cheap to load - - // IDEA: Block pointers loads are expensive if: - // - they cannot be lowered to 2D block reads (they feed a dot operation) - // - temporarily we can look at the "triton_intel_gpu.block_io" attribute, - // if it has it it can be lowered to 2D block reads - // - // - -#define NEW 1 -#ifdef NEW + // Loads that use a block pointer are expensive if they cannot be lowered to + // 2D block read operations. Temporarily leverage the + // "triton_intel_gpu.block_io" attribute to filter out inexpensive loads. Attribute blockIOAttr = op->getAttr(TritonIntelGPUDialect::getBlockIOAttrName()); - if (blockIOAttr) { - llvm::errs() << "load op: " << *op << " is not expensive\n"; + if (blockIOAttr) return false; - } + // Loads that use more threads than elements can be presumed to have a high + // hit-rate that makes them cheap to load. 
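// For instance, with 4 warps of 16 threads each (64 threads in total), a
// tensor of 32 pointers is covered by more threads than elements, so such a
// load is treated as inexpensive; the warp and thread counts below are read
// from the enclosing module.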
if (auto ptrType = getRankedTensorType(base.getType())) { -#else - if (auto ptrType = dyn_cast(base.getType())) { -#endif auto mod = op->getParentOfType(); int numWarps = ttg::TritonGPUDialect::getNumWarps(mod); int threadsPerWarp = ttg::TritonGPUDialect::getThreadsPerWarp(mod); From 4991020d2392c02a5a1209a762ef7f3bbd40b090 Mon Sep 17 00:00:00 2001 From: "Tiotto, Ettore" Date: Mon, 28 Oct 2024 22:24:10 +0000 Subject: [PATCH 27/32] Remove RewriteTensorPointer from the optimization pipeline Signed-off-by: Tiotto, Ettore --- third_party/intel/backend/compiler.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/third_party/intel/backend/compiler.py b/third_party/intel/backend/compiler.py index 86948112b9..1f5c10ba7c 100644 --- a/third_party/intel/backend/compiler.py +++ b/third_party/intel/backend/compiler.py @@ -235,7 +235,8 @@ def make_ttgir(mod, metadata, opt, properties): intel.passes.ttgpuir.add_accelerate_matmul(pm) intel.passes.ttgpuir.add_remove_layout_conversions(pm) intel.passes.ttgpuir.add_materialize_block_pointer(pm) - intel.passes.ttgpuir.add_rewrite_tensor_pointer(pm) + if os.getenv("TRITON_INTEL_REWRITE_TENSOR_POINTER", "0") == "1": + intel.passes.ttgpuir.add_rewrite_tensor_pointer(pm) intel.passes.ttgpuir.add_pipeline(pm, opt.num_stages, False) intel.passes.ttgpuir.add_coalesce(pm) From 9521870ff79c21f1ab1979fb5597da1dbbf43833 Mon Sep 17 00:00:00 2001 From: "Tiotto, Ettore" Date: Wed, 30 Oct 2024 18:52:41 +0000 Subject: [PATCH 28/32] Disable address payload opt experiment Signed-off-by: Tiotto, Ettore --- third_party/intel/include/Analysis/AxisInfo.h | 168 ++---------------- .../include/Dialect/TritonIntelGPU/IR/Utils.h | 2 +- third_party/intel/lib/Analysis/AxisInfo.cpp | 116 +----------- .../LoadStoreOpToLLVM.cpp | 136 +++++++------- .../PatternTritonGPUOpToLLVM.h | 4 +- .../TritonIntelGPUToLLVM/TritonGPUToLLVM.cpp | 4 +- 6 files changed, 94 insertions(+), 336 deletions(-) diff --git a/third_party/intel/include/Analysis/AxisInfo.h b/third_party/intel/include/Analysis/AxisInfo.h index 1fbaba2e0c..c159db9785 100644 --- a/third_party/intel/include/Analysis/AxisInfo.h +++ b/third_party/intel/include/Analysis/AxisInfo.h @@ -1,157 +1,10 @@ #ifndef TRITON_INTEL_ANALYSIS_AXISINFO_H #define TRITON_INTEL_ANALYSIS_AXISINFO_H -#include "mlir/Analysis/DataFlow/SparseAnalysis.h" -#include "llvm/Support/raw_ostream.h" - -#include "mlir/Support/LLVM.h" -#include "triton/Analysis/Utility.h" -#include "triton/Dialect/Triton/IR/Dialect.h" -#include "triton/Dialect/Triton/IR/Utility.h" -#include "triton/Dialect/TritonGPU/IR/Dialect.h" - -#include +#include "triton/Analysis/AxisInfo.h" namespace mlir::triton::intel { -//===----------------------------------------------------------------------===// -// AxisInfo -//===----------------------------------------------------------------------===// - -/// This lattice value represents known information on the axes of a lattice. 
-class AxisInfo { -public: - typedef SmallVector DimVectorT; - -public: - AxisInfo() : AxisInfo({}, {}, {}) {} - - AxisInfo(const DimVectorT &contiguity, const DimVectorT &divisibility, - const DimVectorT &constancy) - : AxisInfo(contiguity, divisibility, constancy, std::nullopt) {} - - AxisInfo(const DimVectorT &contiguity, const DimVectorT &divisibility, - const DimVectorT &constancy, std::optional constantValue) - : contiguity(contiguity), divisibility(divisibility), - constancy(constancy), constantValue(constantValue) { - assert(divisibility.size() == contiguity.size()); - assert(constancy.size() == contiguity.size()); - } - - // contiguity[d] is the length of the shortest sequence of contiguous integers - // along dimension d. - // - // If we have an array of N elements with a contiguity value C, then the array - // can be divided into a list of N/C sequences of C contiguous elements. - // Since we have N = 2^k, C must be a power of two. - // - // For example, the 2D array - // - // [[10, 11, 12, 13, 18, 19, 20, 21], - // [20, 21, 22, 23, 28, 29, 30, 31]] - // - // has contiguity [1, 4], and - // - // [[12, 16, 20, 24], - // [13, 17, 21, 25], - // [14, 18, 22, 26], - // [15, 19, 23, 27], - // [18, 22, 26, 30], - // [19, 23, 27, 31]] - // - // has contiguity [2, 1]. - int64_t getContiguity(size_t dim) const { return contiguity[dim]; } - const DimVectorT &getContiguity() const { return contiguity; } - - // divisibility[d] is the largest power of two that divides the first element - // of all groups of length contiguity[d] along dimension d. - // - // For example, - // - // [[10, 11, 12, 13, 18, 19, 20, 21], - // [20, 21, 22, 23, 28, 29, 30, 31]] - // - // has divisibility [1, 2], and - // - // [[12, 16, 20, 24], - // [13, 17, 21, 25], - // [14, 18, 22, 26], - // [15, 19, 23, 27]] - // - // has divisibility [4, 1]. - // - // On the other hand, - // - // [0, 1, 2, 0, 4, 5, 6, 7] - // - // has divisibility 1 because its contiguity is 1. - int64_t getDivisibility(size_t dim) const { return divisibility[dim]; } - const DimVectorT &getDivisibility() const { return divisibility; } - - // constancy[d] is the length of the shortest sequence of repeating integers - // along dimension d. - // - // This is particularly useful to infer the contiguity of operations (e.g. - // add) involving a constant. - // - // If we have an array of N elements, with a constancy value C, then the array - // can be divided into a list of N/C sequences of C elements with the same - // value. Since we have N = 2^k, C must be a power of two. - // - // For example - // - // [[8, 8, 8, 8, 12, 12, 12, 12], - // [16, 16, 16, 16, 20, 20, 20, 20]] - // - // has constancy [1, 4]. 
- int64_t getConstancy(size_t dim) const { return constancy[dim]; } - const DimVectorT &getConstancy() const { return constancy; } - - int getRank() const { return contiguity.size(); } - - std::optional getConstantValue() const { return constantValue; } - - template - static void - initPessimisticStateFromFunc(int argNumber, T funcOp, DimVectorT *contiguity, - DimVectorT *divisibility, DimVectorT *constancy); - - bool operator==(const AxisInfo &other) const { - return contiguity == other.contiguity && - divisibility == other.divisibility && constancy == other.constancy && - constantValue == other.constantValue; - } - - static AxisInfo getPessimisticValueState(Value value); - - // The gcd of both arguments for each dimension - static AxisInfo join(const AxisInfo &lhs, const AxisInfo &rhs); - - void print(raw_ostream &os) const { - auto print = [&](StringRef name, DimVectorT vec) { - os << name << " = ["; - llvm::interleaveComma(vec, os); - os << "]"; - }; - print("contiguity", contiguity); - print(", divisibility", divisibility); - print(", constancy", constancy); - os << ", constant_value = "; - if (constantValue) - os << *constantValue; - else - os << ""; - } - -private: - DimVectorT contiguity; - DimVectorT divisibility; - DimVectorT constancy; - - // The constant value of the lattice if we can infer it. - std::optional constantValue; -}; - // Module level axis info analysis based on the call graph, assuming that we do // not have recursive functions. // @@ -159,11 +12,13 @@ class AxisInfo { // axis info based on the axis info of all the callers. In the future, we can // perform optimization using function cloning so that each call site will have // unique axis info. -using AxisInfoMapT = DenseMap; -class ModuleAxisInfoAnalysis : public CallGraph { +// using AxisInfoMapT = DenseMap; +class ModuleAxisInfoAnalysis : public triton::ModuleAxisInfoAnalysis { public: explicit ModuleAxisInfoAnalysis(ModuleOp moduleOp) - : CallGraph(moduleOp) { + : triton::ModuleAxisInfoAnalysis(moduleOp) { + funcMap.clear(); + SmallVector funcs; for (auto root : getRoots()) { walk( @@ -187,10 +42,11 @@ class ModuleAxisInfoAnalysis : public CallGraph { } } - AxisInfo *getAxisInfo(Value value) { + AxisInfo *getAxisInfo(Value value) const { auto funcOp = value.getParentRegion()->getParentOfType(); - auto *axisInfoMap = getFuncData(funcOp); + auto *axisInfoMap = + const_cast(this)->getFuncData(funcOp); if (!axisInfoMap) { return nullptr; } @@ -201,9 +57,9 @@ class ModuleAxisInfoAnalysis : public CallGraph { return &(it->second); } - unsigned getPtrContiguity(Value ptr); - unsigned getPtrAlignment(Value ptr); - unsigned getMaskAlignment(Value mask); + unsigned getPtrContiguity(Value ptr) const; + unsigned getPtrAlignment(Value ptr) const; + unsigned getMaskAlignment(Value mask) const; private: void initialize(FunctionOpInterface funcOp); diff --git a/third_party/intel/include/Dialect/TritonIntelGPU/IR/Utils.h b/third_party/intel/include/Dialect/TritonIntelGPU/IR/Utils.h index 6357d4a8c2..7c813a64fa 100644 --- a/third_party/intel/include/Dialect/TritonIntelGPU/IR/Utils.h +++ b/third_party/intel/include/Dialect/TritonIntelGPU/IR/Utils.h @@ -28,7 +28,7 @@ inline unsigned getNumElementsPerThread( ? 
cast(cast(valTy).getPointeeType()) : cast(valTy); auto shapePerCTA = getShapePerCTA(ty); - mlir::triton::intel::AxisInfo &valInfo = *axisInfoAnalysis.getAxisInfo(val); + mlir::triton::AxisInfo &valInfo = *axisInfoAnalysis.getAxisInfo(val); unsigned elemNumBits = getElementBitWidth(ty); unsigned elemNumBytes = std::max(elemNumBits / 8, 1u); diff --git a/third_party/intel/lib/Analysis/AxisInfo.cpp b/third_party/intel/lib/Analysis/AxisInfo.cpp index 378ba01442..6d31af31b7 100644 --- a/third_party/intel/lib/Analysis/AxisInfo.cpp +++ b/third_party/intel/lib/Analysis/AxisInfo.cpp @@ -1159,113 +1159,7 @@ void AxisInfoAnalysis::visitForOpInductionVar( } // anonymous namespace -template -void AxisInfo::initPessimisticStateFromFunc(int argNumber, T funcOp, - DimVectorT *contiguity, - DimVectorT *divisibility, - DimVectorT *constancy) { - // liast of attributes that we care about - SmallVector> retVecs; - retVecs.push_back({contiguity, "tt.contiguity"}); - retVecs.push_back({divisibility, "tt.divisibility"}); - retVecs.push_back({constancy, "tt.constancy"}); - // initialize attributes one by one - for (auto [vec, attrName] : retVecs) { - Attribute attr = funcOp.getArgAttr(argNumber, attrName); - if (auto int_attr = dyn_cast_or_null(attr)) - *vec = DimVectorT(contiguity->size(), int_attr.getValue().getZExtValue()); - if (auto dense_attr = dyn_cast_or_null(attr)) { - auto vals = dense_attr.getValues(); - *vec = DimVectorT(vals.begin(), vals.end()); - } - } -} - -/*static*/ AxisInfo AxisInfo::getPessimisticValueState(Value value) { - auto rank = 1; - if (TensorType ty = dyn_cast(value.getType())) - rank = ty.getRank(); - if (triton::PointerType ty = dyn_cast(value.getType())) - if (TensorType elemTy = dyn_cast(ty.getPointeeType())) - rank = elemTy.getRank(); - - DimVectorT knownContiguity(rank, 1); - DimVectorT knownDivisibility(rank, 1); - DimVectorT knownConstancy(rank, 1); - - BlockArgument blockArg = dyn_cast(value); - - if (blockArg && blockArg.getOwner()->isEntryBlock()) { - Operation *op = blockArg.getOwner()->getParentOp(); - if (auto fun = dyn_cast(op)) - initPessimisticStateFromFunc(blockArg.getArgNumber(), fun, - &knownContiguity, &knownDivisibility, - &knownConstancy); - // llvm codegen check alignment to generate vector load/store - // would be nice if this wasn't the case - else if (auto fun = dyn_cast(op)) - initPessimisticStateFromFunc(blockArg.getArgNumber(), fun, - &knownContiguity, &knownDivisibility, - &knownConstancy); - else if (isa(op)) { - // scf::ForOp, scf::IfOp, scf::WhileOp - // Control flow operations are initialized with "unknown" state: - // the maximum possible divisibility, contiguity, and constancy. - knownDivisibility = DimVectorT(rank, highestPowOf2Divisor(0)); - knownConstancy = DimVectorT(rank, highestPowOf2Divisor(0)); - knownContiguity = DimVectorT(rank, highestPowOf2Divisor(0)); - } - } else if (Operation *op = value.getDefiningOp()) { - if (isa(op)) { - // scf::ForOp, scf::IfOp, scf::WhileOp - // Control flow operations are initialized with "unknown" state: - // the maximum possible divisibility, contiguity, and constancy. - knownDivisibility = DimVectorT(rank, highestPowOf2Divisor(0)); - knownConstancy = DimVectorT(rank, highestPowOf2Divisor(0)); - knownContiguity = DimVectorT(rank, highestPowOf2Divisor(0)); - } - // Other operations are conservatively initialized with the lowest possible - // divisibility, contiguity, and constancy unless they have specified. 
- if (Attribute attr = op->getDiscardableAttr("tt.divisibility")) { - auto vals = cast(attr).getValues(); - knownDivisibility = DimVectorT(vals.begin(), vals.end()); - } - if (Attribute attr = op->getDiscardableAttr("tt.contiguity")) { - auto vals = cast(attr).getValues(); - knownContiguity = DimVectorT(vals.begin(), vals.end()); - } - if (Attribute attr = op->getDiscardableAttr("tt.constancy")) { - auto vals = cast(attr).getValues(); - knownConstancy = DimVectorT(vals.begin(), vals.end()); - } - } - - return AxisInfo(knownContiguity, knownDivisibility, knownConstancy); -} - -/*static*/ AxisInfo AxisInfo::join(const AxisInfo &lhs, const AxisInfo &rhs) { - // If one argument is not initialized, return the other. - if (lhs.getRank() == 0) - return rhs; - if (rhs.getRank() == 0) - return lhs; - DimVectorT contiguity; - DimVectorT divisibility; - DimVectorT constancy; - for (auto d = 0; d < lhs.getRank(); ++d) { - contiguity.push_back(gcd(lhs.getContiguity(d), rhs.getContiguity(d))); - divisibility.push_back(gcd(lhs.getDivisibility(d), rhs.getDivisibility(d))); - constancy.push_back(gcd(lhs.getConstancy(d), rhs.getConstancy(d))); - } - std::optional constantValue; - if (lhs.getConstantValue().has_value() && - rhs.getConstantValue().has_value() && - lhs.getConstantValue() == rhs.getConstantValue()) - constantValue = lhs.getConstantValue(); - return AxisInfo(contiguity, divisibility, constancy, constantValue); -} - -unsigned ModuleAxisInfoAnalysis::getPtrContiguity(Value ptr) { +unsigned ModuleAxisInfoAnalysis::getPtrContiguity(Value ptr) const { auto tensorTy = ttgi::getRankedTensorType(ptr.getType()); if (!tensorTy) return 1; @@ -1287,7 +1181,7 @@ unsigned ModuleAxisInfoAnalysis::getPtrContiguity(Value ptr) { return contiguity; } -unsigned ModuleAxisInfoAnalysis::getPtrAlignment(Value ptr) { +unsigned ModuleAxisInfoAnalysis::getPtrAlignment(Value ptr) const { auto tensorTy = ttgi::getRankedTensorType(ptr.getType()); if (!tensorTy) return 1; @@ -1298,7 +1192,9 @@ unsigned ModuleAxisInfoAnalysis::getPtrAlignment(Value ptr) { auto order = triton::gpu::getOrder(layout); auto maxMultipleBytes = axisInfo->getDivisibility(order[0]); auto maxContig = axisInfo->getContiguity(order[0]); - auto elemNumBits = triton::getPointeeBitWidth(tensorTy); + unsigned elemNumBits = isTensorPointerType(ptr.getType()) + ? 
tensorTy.getElementType().getIntOrFloatBitWidth() + : triton::getPointeeBitWidth(tensorTy); auto elemNumBytes = std::max(elemNumBits / 8, 1); auto maxMultiple = std::max(maxMultipleBytes / elemNumBytes, 1); unsigned alignment = std::min(maxMultiple, maxContig); @@ -1315,7 +1211,7 @@ unsigned ModuleAxisInfoAnalysis::getPtrAlignment(Value ptr) { return alignment; } -unsigned ModuleAxisInfoAnalysis::getMaskAlignment(Value mask) { +unsigned ModuleAxisInfoAnalysis::getMaskAlignment(Value mask) const { auto tensorTy = ttgi::getRankedTensorType(mask.getType()); if (!tensorTy) return 1; diff --git a/third_party/intel/lib/TritonIntelGPUToLLVM/LoadStoreOpToLLVM.cpp b/third_party/intel/lib/TritonIntelGPUToLLVM/LoadStoreOpToLLVM.cpp index 8d7fee8e38..20ae11503c 100644 --- a/third_party/intel/lib/TritonIntelGPUToLLVM/LoadStoreOpToLLVM.cpp +++ b/third_party/intel/lib/TritonIntelGPUToLLVM/LoadStoreOpToLLVM.cpp @@ -11,7 +11,6 @@ #include "intel/include/Dialect/TritonIntelGPU/IR/Attributes.h" #include "intel/include/Dialect/TritonIntelGPU/Transforms/Utility.h" -#include "triton/Dialect/TritonGPU/IR/Attributes.h" using namespace mlir; using namespace mlir::triton; @@ -162,23 +161,25 @@ getWarpsPerCTA(const ArrayRef tensorShape, // Contains some helper functions for both Load and Store conversions. struct LoadStoreConversionBase { - explicit LoadStoreConversionBase(const triton::intel::TargetInfo &targetInfo, - ModuleAxisInfoAnalysis &axisAnalysisPass) + explicit LoadStoreConversionBase( + const triton::intel::TargetInfo &targetInfo, + const triton::intel::ModuleAxisInfoAnalysis &axisAnalysisPass) : targetInfo(targetInfo), axisAnalysisPass(axisAnalysisPass) {} unsigned getContiguity(Value ptr) const { - auto tensorTy = dyn_cast(ptr.getType()); - if (!tensorTy) - return 1; return axisAnalysisPass.getPtrContiguity(ptr); } unsigned getVectorSize(Value ptr) const { - auto tensorTy = dyn_cast(ptr.getType()); + auto tensorTy = getRankedTensorType(ptr.getType()); if (!tensorTy) return 1; - auto contiguity = getContiguity(ptr); - auto pointeeBitWidth = triton::getPointeeBitWidth(tensorTy); + + unsigned contiguity = getContiguity(ptr); + unsigned pointeeBitWidth = + isTensorPointerType(ptr.getType()) + ? tensorTy.getElementType().getIntOrFloatBitWidth() + : triton::getPointeeBitWidth(tensorTy); // The maximum vector size is 128 bits. 
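// For example, an f16 element (16 bits) with contiguity 16 gives
// min(128 / 16, 16) = 8 elements per vectorized access.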
return std::min(128 / pointeeBitWidth, contiguity); } @@ -194,7 +195,7 @@ struct LoadStoreConversionBase { ArrayRef boundaryCheck = {}, std::optional padding = std::nullopt) const { - auto rank = tensorType.getRank(); + size_t rank = tensorType.getRank(); // The block pointer struct is expected to have the following layout: // Struct { // Value offset[rank]; @@ -290,7 +291,7 @@ struct LoadStoreConversionBase { } protected: - ModuleAxisInfoAnalysis &axisAnalysisPass; + const triton::intel::ModuleAxisInfoAnalysis &axisAnalysisPass; const triton::intel::TargetInfo &targetInfo; }; @@ -300,10 +301,11 @@ struct PrefetchOpConversion using ConvertTritonGPUOpToLLVMPattern< triton::gpu::intel::PrefetchOp>::ConvertTritonGPUOpToLLVMPattern; - PrefetchOpConversion(TritonGPUToLLVMTypeConverter &converter, - const triton::intel::TargetInfo &targetInfo, - ModuleAxisInfoAnalysis &axisAnalysisPass, - PatternBenefit benefit) + PrefetchOpConversion( + TritonGPUToLLVMTypeConverter &converter, + const triton::intel::TargetInfo &targetInfo, + const triton::intel::ModuleAxisInfoAnalysis &axisAnalysisPass, + PatternBenefit benefit) : ConvertTritonGPUOpToLLVMPattern( converter, benefit), LoadStoreConversionBase(targetInfo, axisAnalysisPass) {} @@ -476,10 +478,11 @@ struct LoadOpConversion using ValueTable = std::map, Value>; - LoadOpConversion(TritonIntelGPUToLLVMTypeConverter &converter, - const triton::intel::TargetInfo &targetInfo, - ModuleAxisInfoAnalysis &axisAnalysisPass, - PatternBenefit benefit) + LoadOpConversion( + TritonIntelGPUToLLVMTypeConverter &converter, + const triton::intel::TargetInfo &targetInfo, + const triton::intel::ModuleAxisInfoAnalysis &axisAnalysisPass, + PatternBenefit benefit) : ConvertTritonGPUOpToLLVMPattern(converter, benefit), LoadStoreConversionBase(targetInfo, axisAnalysisPass) {} @@ -809,6 +812,10 @@ struct LoadOpConversion LogicalResult matchAndRewrite(triton::LoadOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { + if (isTensorPointerType(op.getPtr().getType()) && + rewriteTensorPointerLoad(op, adaptor, rewriter).succeeded()) + return success(); + auto loc = op->getLoc(); auto typeConverter = getTypeConverter(); auto *ctx = rewriter.getContext(); @@ -819,26 +826,23 @@ struct LoadOpConversion unsigned numElems = getTotalElemsPerThread(op.getType()); unsigned vec = 1; - SmallVector ptrElems; - SmallVector maskElems; - + SmallVector ptrElems, maskElems, otherElems; bool otherIsSplatConstInt = false; int64_t splatVal = 0; - SmallVector otherElems; if (isTensorPointerType(op.getPtr().getType())) { - if (rewriteTensorPointerLoad(op, adaptor, rewriter).succeeded()) { - return success(); - } else { - // TODO: (johnlu) set the vector size > 1; Need to prove the memory is - // contiguous on the fast changing dim when fallback to gather load. 
- Type resultType = op.getType(); - auto tensorType = cast(resultType); - std::tie(ptrElems, maskElems, otherElems) = - convertBlockPtrToTensorOfPtr( - loc, adaptor.getPtr(), tensorType, valueElemTy, rewriter, - op.getBoundaryCheck(), op.getPadding()); - } + Value ptr = op.getPtr(); + Value mask = op.getMask(); + Value llMask = adaptor.getMask(); + vec = getVectorSize(ptr); + if (llMask) + vec = std::min(vec, getMaskAlignment(mask)); + + Type resultType = op.getType(); + auto tensorType = cast(resultType); + std::tie(ptrElems, maskElems, otherElems) = convertBlockPtrToTensorOfPtr( + loc, adaptor.getPtr(), tensorType, valueElemTy, rewriter, + op.getBoundaryCheck(), op.getPadding()); } else { // original values Value ptr = op.getPtr(); @@ -913,7 +917,7 @@ struct LoadOpConversion for (size_t s = 0; s < size; ++s) { Value falseVal = otherElems[vecStart + ii * size + s]; Value sVal = createIndexAttrConstant( - rewriter, loc, this->getTypeConverter()->getIndexType(), s); + rewriter, loc, typeConverter->getIndexType(), s); v = insert_element(vecTy, v, falseVal, sVal); } v = bitcast(v, IntegerType::get(ctx, width)); @@ -925,7 +929,7 @@ struct LoadOpConversion } Value iiVal = createIndexAttrConstant( - rewriter, loc, this->getTypeConverter()->getIndexType(), ii); + rewriter, loc, typeConverter->getIndexType(), ii); if (nWords > 1) { other_ = insert_element(retTy, other_, v, iiVal); } else { @@ -984,10 +988,11 @@ struct StoreOpConversion using ConvertTritonGPUOpToLLVMPattern< triton::StoreOp>::ConvertTritonGPUOpToLLVMPattern; - StoreOpConversion(TritonIntelGPUToLLVMTypeConverter &converter, - const triton::intel::TargetInfo &targetInfo, - ModuleAxisInfoAnalysis &axisAnalysisPass, - PatternBenefit benefit) + StoreOpConversion( + TritonIntelGPUToLLVMTypeConverter &converter, + const triton::intel::TargetInfo &targetInfo, + const triton::intel::ModuleAxisInfoAnalysis &axisAnalysisPass, + PatternBenefit benefit) : ConvertTritonGPUOpToLLVMPattern(converter, benefit), LoadStoreConversionBase(targetInfo, axisAnalysisPass) {} @@ -1117,13 +1122,16 @@ struct StoreOpConversion LogicalResult matchAndRewrite(triton::StoreOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - auto loc = op->getLoc(); - MLIRContext *ctx = rewriter.getContext(); + if (isTensorPointerType(op.getPtr().getType())) + if (rewriteTensorPointerStore(op, adaptor, rewriter).succeeded()) + return success(); - Value ptr = op.getPtr(); + Location loc = op->getLoc(); + MLIRContext *ctx = rewriter.getContext(); Value value = op.getValue(); - auto valueTy = value.getType(); + Value ptr = op.getPtr(); + Type valueTy = value.getType(); Type valueElemTy = typeConverter->convertType(getElementTypeOrSelf(valueTy)); SmallVector ptrElems; @@ -1131,17 +1139,12 @@ struct StoreOpConversion unsigned vec = 1; if (isTensorPointerType(ptr.getType())) { - if (rewriteTensorPointerStore(op, adaptor, rewriter).succeeded()) { - return success(); - } else { - // fallback to scatter store. - auto tensorType = cast(valueTy); - SmallVector dummyOther; - std::tie(ptrElems, maskElems, dummyOther) = - convertBlockPtrToTensorOfPtr(loc, adaptor.getPtr(), tensorType, - valueElemTy, rewriter, - op.getBoundaryCheck()); - } + // fallback to scatter store. 
+ auto tensorType = cast(valueTy); + SmallVector dummyOther; + std::tie(ptrElems, maskElems, dummyOther) = convertBlockPtrToTensorOfPtr( + loc, adaptor.getPtr(), tensorType, valueElemTy, rewriter, + op.getBoundaryCheck()); } else { Value llPtr = adaptor.getPtr(); Value llMask = adaptor.getMask(); @@ -1245,10 +1248,11 @@ struct AtomicCASOpConversion using ConvertTritonGPUOpToLLVMPattern< triton::AtomicCASOp>::ConvertTritonGPUOpToLLVMPattern; - AtomicCASOpConversion(TritonIntelGPUToLLVMTypeConverter &converter, - const triton::intel::TargetInfo &targetInfo, - ModuleAxisInfoAnalysis &axisAnalysisPass, - PatternBenefit benefit) + AtomicCASOpConversion( + TritonIntelGPUToLLVMTypeConverter &converter, + const triton::intel::TargetInfo &targetInfo, + const triton::intel::ModuleAxisInfoAnalysis &axisAnalysisPass, + PatternBenefit benefit) : ConvertTritonGPUOpToLLVMPattern(converter, benefit), LoadStoreConversionBase(targetInfo, axisAnalysisPass) {} @@ -1362,10 +1366,11 @@ struct AtomicRMWOpConversion using ConvertTritonGPUOpToLLVMPattern< triton::AtomicRMWOp>::ConvertTritonGPUOpToLLVMPattern; - AtomicRMWOpConversion(TritonIntelGPUToLLVMTypeConverter &converter, - const triton::intel::TargetInfo &targetInfo, - ModuleAxisInfoAnalysis &axisAnalysisPass, - PatternBenefit benefit) + AtomicRMWOpConversion( + TritonIntelGPUToLLVMTypeConverter &converter, + const triton::intel::TargetInfo &targetInfo, + const triton::intel::ModuleAxisInfoAnalysis &axisAnalysisPass, + PatternBenefit benefit) : ConvertTritonGPUOpToLLVMPattern(converter, benefit), LoadStoreConversionBase(targetInfo, axisAnalysisPass) {} @@ -1625,7 +1630,8 @@ struct AtomicRMWOpConversion void mlir::triton::intel::populateLoadStoreOpToLLVMPatterns( TritonIntelGPUToLLVMTypeConverter &typeConverter, const TargetInfo &targetInfo, RewritePatternSet &patterns, - ModuleAxisInfoAnalysis &axisInfoAnalysis, PatternBenefit benefit) { + const intel::ModuleAxisInfoAnalysis &axisInfoAnalysis, + PatternBenefit benefit) { patterns.add( typeConverter, targetInfo, axisInfoAnalysis, benefit); diff --git a/third_party/intel/lib/TritonIntelGPUToLLVM/PatternTritonGPUOpToLLVM.h b/third_party/intel/lib/TritonIntelGPUToLLVM/PatternTritonGPUOpToLLVM.h index b605776b52..40116a17ca 100644 --- a/third_party/intel/lib/TritonIntelGPUToLLVM/PatternTritonGPUOpToLLVM.h +++ b/third_party/intel/lib/TritonIntelGPUToLLVM/PatternTritonGPUOpToLLVM.h @@ -3,8 +3,8 @@ #include "TargetInfo.h" #include "TritonGPUToLLVMBase.h" +#include "intel/include/Analysis/AxisInfo.h" #include "intel/include/TritonIntelGPUToLLVM/TypeConverter.h" -#include "triton/Analysis/AxisInfo.h" namespace mlir::triton::intel { @@ -53,7 +53,7 @@ void populateHistogramOpToLLVMPatterns(LLVMTypeConverter &typeConverter, void populateLoadStoreOpToLLVMPatterns( TritonIntelGPUToLLVMTypeConverter &typeConverter, const TargetInfo &targetInfo, RewritePatternSet &patterns, - ModuleAxisInfoAnalysis &axisInfoAnalysis, PatternBenefit benefit); + const ModuleAxisInfoAnalysis &axisInfoAnalysis, PatternBenefit benefit); void populateReduceOpToLLVMPatterns(LLVMTypeConverter &typeConverter, RewritePatternSet &patterns, diff --git a/third_party/intel/lib/TritonIntelGPUToLLVM/TritonGPUToLLVM.cpp b/third_party/intel/lib/TritonIntelGPUToLLVM/TritonGPUToLLVM.cpp index 998b7204cf..3d3bbb3015 100644 --- a/third_party/intel/lib/TritonIntelGPUToLLVM/TritonGPUToLLVM.cpp +++ b/third_party/intel/lib/TritonIntelGPUToLLVM/TritonGPUToLLVM.cpp @@ -7,6 +7,7 @@ #include "mlir/Dialect/LLVMIR/LLVMDialect.h" #include 
"mlir/Dialect/SPIRV/IR/SPIRVDialect.h" +#include "intel/include/Analysis/AxisInfo.h" #include "intel/include/Dialect/TritonGEN/IR/TritonGENDialect.h" #include "intel/include/Dialect/TritonIntelGPU/IR/Dialect.h" #include "intel/include/GPUToTritonGEN/GPUToTritonGENPass.h" @@ -14,7 +15,6 @@ #include "intel/include/TritonIntelGPUToLLVM/Passes.h" #include "triton/Analysis/Allocation.h" -#include "triton/Analysis/AxisInfo.h" #include "triton/Analysis/Membar.h" #include "triton/Conversion/TritonGPUToLLVM/PatternTritonGPUOpToLLVM.h" #include "triton/Dialect/Triton/IR/Dialect.h" @@ -114,7 +114,7 @@ struct ConvertTritonGPUToLLVM return signalPassFailure(); } - ModuleAxisInfoAnalysis axisInfoAnalysis(mod); + intel::ModuleAxisInfoAnalysis axisInfoAnalysis(mod); OpBuilder::InsertPoint indexInsertPoint; RewritePatternSet patterns(context); From 00f8432d271006b46e08900dfcde66733d29cbe9 Mon Sep 17 00:00:00 2001 From: "Tiotto, Ettore" Date: Thu, 31 Oct 2024 19:10:53 +0000 Subject: [PATCH 29/32] Fix test_block_pointer.py:test_block_copy Signed-off-by: Tiotto, Ettore --- third_party/intel/lib/Analysis/AxisInfo.cpp | 22 ++++++++++++++++++--- 1 file changed, 19 insertions(+), 3 deletions(-) diff --git a/third_party/intel/lib/Analysis/AxisInfo.cpp b/third_party/intel/lib/Analysis/AxisInfo.cpp index 6d31af31b7..f1bf97edf0 100644 --- a/third_party/intel/lib/Analysis/AxisInfo.cpp +++ b/third_party/intel/lib/Analysis/AxisInfo.cpp @@ -558,6 +558,11 @@ class LoadOpAxisInfoVisitor final : public AxisInfoVisitorImpl { // If pointers and mask both have constancy properties, those properties // will also extend to output. AxisInfo ptrInfo = operands[0]->getValue(); + + llvm::errs() << "ptrInfo: "; + ptrInfo.print(llvm::dbgs()); + llvm::dbgs() << "\n"; + std::optional maskInfo; if (operands.size() > 1) { maskInfo = operands[1]->getValue(); @@ -1030,13 +1035,24 @@ class MakeTensorPtrOpAxisInfoVisitor final strideInfo[dim].getConstantValue() == 1 ? blkShape[dim] : 1); divisibility.push_back( contiguity[dim] > 1 - ? std::min(ptrDivisibility, - strideInfo[dim == 0 ? 1 : 0].getDivisibility()[0]) + ? std::min( + ptrDivisibility, + (rank == 2 ? strideInfo[dim == 0 ? 1 : 0] : strideInfo[dim]) + .getDivisibility()[0]) : 1); constancy.push_back(1); } - return AxisInfo(contiguity, divisibility, constancy); + auto axisInfo = AxisInfo(contiguity, divisibility, constancy); + + LLVM_DEBUG({ + std::string axisStr; + llvm::raw_string_ostream os(axisStr); + axisInfo.print(os); + LDBG("-- " << axisStr); + }); + + return axisInfo; } }; From 17f5b25d6a32141dbea1293c02b7a1f843cc5473 Mon Sep 17 00:00:00 2001 From: "Tiotto, Ettore" Date: Fri, 1 Nov 2024 15:56:31 +0000 Subject: [PATCH 30/32] Address code review comments Signed-off-by: Tiotto, Ettore --- third_party/intel/lib/Analysis/AxisInfo.cpp | 4 ---- 1 file changed, 4 deletions(-) diff --git a/third_party/intel/lib/Analysis/AxisInfo.cpp b/third_party/intel/lib/Analysis/AxisInfo.cpp index f1d5c94ee4..ae957eb8ab 100644 --- a/third_party/intel/lib/Analysis/AxisInfo.cpp +++ b/third_party/intel/lib/Analysis/AxisInfo.cpp @@ -559,10 +559,6 @@ class LoadOpAxisInfoVisitor final : public AxisInfoVisitorImpl { // will also extend to output. 
AxisInfo ptrInfo = operands[0]->getValue(); - llvm::errs() << "ptrInfo: "; - ptrInfo.print(llvm::dbgs()); - llvm::dbgs() << "\n"; - std::optional maskInfo; if (operands.size() > 1) { maskInfo = operands[1]->getValue(); From 0b21a82a8d822ab3910c16cb0f4566081f8117a6 Mon Sep 17 00:00:00 2001 From: "Tiotto, Ettore" Date: Fri, 1 Nov 2024 17:12:04 +0000 Subject: [PATCH 31/32] Address code review comments Signed-off-by: Tiotto, Ettore --- .../test/unit/language/test_block_pointer.py | 40 +++++++++---------- 1 file changed, 19 insertions(+), 21 deletions(-) diff --git a/python/test/unit/language/test_block_pointer.py b/python/test/unit/language/test_block_pointer.py index 5f77d5f85f..9dee5dea7f 100644 --- a/python/test/unit/language/test_block_pointer.py +++ b/python/test/unit/language/test_block_pointer.py @@ -7,29 +7,29 @@ @triton.jit -def block_copy_kernel(a_ptr, b_ptr, N, BLOCK_SIZE: tl.constexpr): +def block_copy_kernel(a_ptr, b_ptr, N, BLOCK_SIZE: tl.constexpr, padding_option: tl.constexpr): pid = tl.program_id(0) # We only copy half of the data to see if the padding works a_block_ptr = tl.make_block_ptr(base=a_ptr, shape=(N // 2, ), strides=(1, ), offsets=(pid * BLOCK_SIZE, ), block_shape=(BLOCK_SIZE, ), order=(0, )) b_block_ptr = tl.make_block_ptr(base=b_ptr, shape=(N, ), strides=(1, ), offsets=(pid * BLOCK_SIZE, ), block_shape=(BLOCK_SIZE, ), order=(0, )) - # if padding_option is None: - a = tl.load(a_block_ptr, boundary_check=(0, )) - # else: - # a = tl.load(a_block_ptr, boundary_check=(0, ), padding_option=padding_option) + if padding_option is None: + a = tl.load(a_block_ptr, boundary_check=(0, )) + else: + a = tl.load(a_block_ptr, boundary_check=(0, ), padding_option=padding_option) tl.store(b_block_ptr, a, boundary_check=(0, )) @pytest.mark.interpreter -@pytest.mark.parametrize("dtypes_str, n", [ # - (dtypes_str, n) - # for dtypes_str in (("bool", "bool"), ("int16", "int16"), ("int32", "int32"), ("float16", "float16"), - # ("float32", "float32"), ("bfloat16", "bfloat16")) - for dtypes_str in [("float16", "float16")] - for n in [64] +@pytest.mark.parametrize("dtypes_str, n, padding_option", [ # + (dtypes_str, n, padding) + for dtypes_str in (("bool", "bool"), ("int16", "int16"), ("int32", "int32"), ("float16", "float16"), + ("float32", "float32"), ("bfloat16", "bfloat16")) + for n in (64, 128, 256, 512, 1024) + for padding in (None, "zero", "nan") # ]) -def test_block_copy(dtypes_str, n, device): +def test_block_copy(dtypes_str, n, padding_option, device): src_dtype_str = dtypes_str[0] dst_dtype_str = dtypes_str[1] src_dtype = getattr(torch, src_dtype_str) @@ -37,23 +37,21 @@ def test_block_copy(dtypes_str, n, device): check_type_supported(src_dtype, device) check_type_supported(dst_dtype, device) if src_dtype_str in ("bool", "int16", "int32"): - # if padding_option == "nan": - # pytest.xfail("Padding with NaN is not supported for integer types") + if padding_option == "nan": + pytest.xfail("Padding with NaN is not supported for integer types") a = torch.randint(0, 2, (n, ), device=device, dtype=src_dtype) else: a = torch.randn((n, ), device=device, dtype=src_dtype) b = torch.zeros((n, ), device=device, dtype=dst_dtype) grid = lambda meta: (triton.cdiv(n, meta["BLOCK_SIZE"]), ) - block_copy_kernel[grid](a_ptr=a, b_ptr=b, N=n, BLOCK_SIZE=64) + block_copy_kernel[grid](a_ptr=a, b_ptr=b, N=n, BLOCK_SIZE=64, padding_option=padding_option) a.to(dst_dtype) assert torch.all(a[0:n // 2] == b[0:n // 2]) - - -# if padding_option == "zero": -# assert torch.all(b[n // 2:n] == 0) -# elif 
padding_option == "nan": -# assert torch.all(torch.isnan(b[n // 2:n])) + if padding_option == "zero": + assert torch.all(b[n // 2:n] == 0) + elif padding_option == "nan": + assert torch.all(torch.isnan(b[n // 2:n])) @triton.jit From 2d229071c3ca11189d5a0e3857ca157a7b8f59f0 Mon Sep 17 00:00:00 2001 From: "Tiotto, Ettore" Date: Fri, 1 Nov 2024 18:03:00 +0000 Subject: [PATCH 32/32] Add vectorization support for store as well Signed-off-by: Tiotto, Ettore --- third_party/intel/lib/Analysis/AxisInfo.cpp | 1 - .../LoadStoreOpToLLVM.cpp | 60 +++++++------------ 2 files changed, 22 insertions(+), 39 deletions(-) diff --git a/third_party/intel/lib/Analysis/AxisInfo.cpp b/third_party/intel/lib/Analysis/AxisInfo.cpp index ae957eb8ab..7161dedf7a 100644 --- a/third_party/intel/lib/Analysis/AxisInfo.cpp +++ b/third_party/intel/lib/Analysis/AxisInfo.cpp @@ -558,7 +558,6 @@ class LoadOpAxisInfoVisitor final : public AxisInfoVisitorImpl { // If pointers and mask both have constancy properties, those properties // will also extend to output. AxisInfo ptrInfo = operands[0]->getValue(); - std::optional maskInfo; if (operands.size() > 1) { maskInfo = operands[1]->getValue(); diff --git a/third_party/intel/lib/TritonIntelGPUToLLVM/LoadStoreOpToLLVM.cpp b/third_party/intel/lib/TritonIntelGPUToLLVM/LoadStoreOpToLLVM.cpp index 615ccaf6d4..d5d6fae327 100644 --- a/third_party/intel/lib/TritonIntelGPUToLLVM/LoadStoreOpToLLVM.cpp +++ b/third_party/intel/lib/TritonIntelGPUToLLVM/LoadStoreOpToLLVM.cpp @@ -827,46 +827,34 @@ struct LoadOpConversion rewriteTensorPointerLoad(op, adaptor, rewriter).succeeded()) return success(); - auto loc = op->getLoc(); - auto typeConverter = getTypeConverter(); - auto *ctx = rewriter.getContext(); + Location loc = op->getLoc(); + TritonIntelGPUToLLVMTypeConverter *typeConverter = getTypeConverter(); + MLIRContext *ctx = rewriter.getContext(); + Value ptr = op.getPtr(); + Value mask = op.getMask(); + Value llMask = adaptor.getMask(); // Determine the vectorization size Type valueElemTy = typeConverter->convertType(getElementTypeOrSelf(op.getType())); unsigned numElems = getTotalElemsPerThread(op.getType()); - unsigned vec = 1; + unsigned vec = getVectorSize(ptr); + if (llMask) + vec = std::min(vec, getMaskAlignment(mask)); SmallVector ptrElems, maskElems, otherElems; bool otherIsSplatConstInt = false; int64_t splatVal = 0; - if (isTensorPointerType(op.getPtr().getType())) { - Value ptr = op.getPtr(); - Value mask = op.getMask(); - Value llMask = adaptor.getMask(); - vec = getVectorSize(ptr); - if (llMask) - vec = std::min(vec, getMaskAlignment(mask)); - - Type resultType = op.getType(); - auto tensorType = cast(resultType); + if (isTensorPointerType(ptr.getType())) { + auto tensorType = cast(op.getType()); std::tie(ptrElems, maskElems, otherElems) = convertBlockPtrToTensorOfPtr( loc, adaptor.getPtr(), tensorType, valueElemTy, rewriter, op.getBoundaryCheck(), op.getPadding()); } else { - // original values - Value ptr = op.getPtr(); Value other = op.getOther(); - Value mask = op.getMask(); - - // adaptor values Value llPtr = adaptor.getPtr(); - Value llMask = adaptor.getMask(); Value llOther = adaptor.getOther(); - vec = getVectorSize(ptr); - if (llMask) - vec = std::min(vec, getMaskAlignment(mask)); // Get the LLVM values for pointers ptrElems = unpackLLElements(loc, llPtr, rewriter); @@ -1141,19 +1129,23 @@ struct StoreOpConversion return success(); Location loc = op->getLoc(); + TritonIntelGPUToLLVMTypeConverter *typeConverter = getTypeConverter(); MLIRContext *ctx = 
rewriter.getContext(); - Value value = op.getValue(); - Value ptr = op.getPtr(); + Value mask = op.getMask(); + Value llMask = adaptor.getMask(); + + // Determine the vectorization size + Value value = op.getValue(); Type valueTy = value.getType(); Type valueElemTy = typeConverter->convertType(getElementTypeOrSelf(valueTy)); - SmallVector ptrElems; - SmallVector maskElems; - unsigned vec = 1; + SmallVector ptrElems, maskElems; + unsigned vec = getVectorSize(ptr); + if (llMask) + vec = std::min(vec, getMaskAlignment(mask)); if (isTensorPointerType(ptr.getType())) { - // fallback to scatter store. auto tensorType = cast(valueTy); SmallVector dummyOther; std::tie(ptrElems, maskElems, dummyOther) = convertBlockPtrToTensorOfPtr( @@ -1161,19 +1153,11 @@ struct StoreOpConversion op.getBoundaryCheck()); } else { Value llPtr = adaptor.getPtr(); - Value llMask = adaptor.getMask(); - - vec = getVectorSize(ptr); ptrElems = unpackLLElements(loc, llPtr, rewriter); - // Determine the vectorization size if (llMask) { - Value mask = op.getMask(); maskElems = unpackLLElements(loc, llMask, rewriter); - - unsigned maskAlign = getMaskAlignment(mask); - vec = std::min(vec, maskAlign); } } @@ -1183,7 +1167,7 @@ struct StoreOpConversion assert(!maskElems.size() || valueElems.size() == maskElems.size() && "Mask size mismatch"); - Value mask = redundantDataMask(valueTy, rewriter, loc, targetInfo); + mask = redundantDataMask(valueTy, rewriter, loc, targetInfo); const size_t dtsize = std::max(1, valueElemTy.getIntOrFloatBitWidth() / 8); const size_t valueElemNBits = dtsize * 8;
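Note: the vector width picked by the load and store lowerings above can be summarized independently of the conversion classes. The following is a minimal sketch, assuming the contiguity, element bit-width, and mask alignment have already been obtained from the axis-info analysis (getPtrContiguity / getMaskAlignment); it is an illustration, not code from this patch series.

#include <algorithm>

// Sketch: mirrors how LoadOpConversion/StoreOpConversion choose `vec`.
//   contiguity    - pointer contiguity reported by the axis-info analysis
//   elemBitWidth  - element size in bits
//   maskAlignment - mask alignment, or 0 when the access is unmasked
static unsigned computeVectorWidth(unsigned contiguity, unsigned elemBitWidth,
                                   unsigned maskAlignment) {
  // The maximum vector size is 128 bits.
  unsigned vec = std::min(128u / elemBitWidth, contiguity);
  if (maskAlignment != 0)
    vec = std::min(vec, maskAlignment);
  return vec;
}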