Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions bin/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,7 @@ add_llvm_executable(triton-tensor-layout triton-tensor-layout.cpp PARTIAL_SOURCE
target_link_libraries(triton-tensor-layout PRIVATE
TritonGPUIR
TritonNvidiaGPUIR
TritonIntelGPUIR
${triton_libs}
${conversion_libs}
${dialect_libs}
Expand Down
12 changes: 2 additions & 10 deletions bin/triton-tensor-layout.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -80,17 +80,9 @@ static cl::opt<std::string> TensorStr(
//===--------------------------------------------------------------------===//

LogicalResult layoutPrint(RankedTensorType tensorType, raw_ostream &os) {
StringRef dialectName = tensorType.getEncoding().getDialect().getNamespace();

// Dispatch to the corresponding dialect helper function to print the layout.
if (dialectName == "triton_gpu") {
os << triton::gpu::getLayoutStr(tensorType, UseHWPointOfView);
return success();
}

llvm::errs() << "Unsupported tensor layout attribute: "
<< tensorType.getEncoding() << "\n";
return failure();
os << triton::gpu::getLayoutStr(tensorType, UseHWPointOfView);
return success();
}

LogicalResult printLayoutFromFile(MLIRContext *context, StringRef filename,
Expand Down
312 changes: 151 additions & 161 deletions test/Conversion/intel/dot_layout_offset.mlir

Large diffs are not rendered by default.

380 changes: 260 additions & 120 deletions test/TritonIntelGPU/tritonintelgpu-convert-layout-shortcut.mlir

Large diffs are not rendered by default.

137 changes: 65 additions & 72 deletions test/TritonIntelGPU/tritonintlgpu-nested-layout.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -69,14 +69,13 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 :
// CHECK-DAG: %[[CST_0:.*]] = llvm.mlir.constant(0 : i32) : i32
// CHECK-DAG: %[[CST_1:.*]] = llvm.mlir.constant(1 : i32) : i32
// CHECK-DAG: %[[CST_2:.*]] = llvm.mlir.constant(2 : i32) : i32
// CHECK-DAG: %[[CST_4:.*]] = llvm.mlir.constant(4 : i32) : i32
// CHECK-DAG: %[[CST_8:.*]] = llvm.mlir.constant(8 : i32) : i32
// CHECK-DAG: %[[CST_16:.*]] = llvm.mlir.constant(16 : i32) : i32
// CHECK-DAG: %[[CST_32:.*]] = llvm.mlir.constant(32 : i32) : i32
// CHECK-DAG: %[[CST_3:.*]] = llvm.mlir.constant(3 : i32) : i32
// CHECK-DAG: %[[CST_4:.*]] = llvm.mlir.constant(4 : i32) : i32
// CHECK-DAG: %[[CST_5:.*]] = llvm.mlir.constant(5 : i32) : i32
// CHECK-DAG: %[[CST_6:.*]] = llvm.mlir.constant(6 : i32) : i32
// CHECK-DAG: %[[CST_7:.*]] = llvm.mlir.constant(7 : i32) : i32
// CHECK-DAG: %[[CST_8:.*]] = llvm.mlir.constant(8 : i32) : i32
// CHECK-DAG: %[[CST_16:.*]] = llvm.mlir.constant(16 : i32) : i32
// CHECK-DAG: %[[CST_17:.*]] = llvm.mlir.constant(17 : i32) : i32
// CHECK-DAG: %[[CST_18:.*]] = llvm.mlir.constant(18 : i32) : i32
// CHECK-DAG: %[[CST_19:.*]] = llvm.mlir.constant(19 : i32) : i32
Expand All @@ -86,43 +85,46 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 :
// CHECK-DAG: %[[CST_23:.*]] = llvm.mlir.constant(23 : i32) : i32
// CHECK: %[[THREAD_ID:.*]] = llvm.call spir_funccc @_Z12get_local_idj(%[[CST_0]])
// CHECK: %[[THREAD_ID_32:.*]] = llvm.trunc %[[THREAD_ID]] : i64 to i32
// CHECK: %[[WARP_ID:.*]] = llvm.udiv %[[THREAD_ID_32]], %[[CST_16]] : i32
// CHECK: %[[LANE_ID:.*]] = llvm.urem %[[THREAD_ID_32]], %[[CST_16]] : i32
// CHECK: %[[VAL_29:.*]] = llvm.udiv %[[WARP_ID]], %[[CST_2]] : i32
// CHECK: %[[WARP_ID_X:.*]] = llvm.urem %[[VAL_29]], %[[CST_2]] : i32
// CHECK: %[[ROUNDED_WARP_ID_X:.*]] = llvm.urem %[[WARP_ID_X]], %[[CST_4]] : i32
// CHECK: %[[WARP_OFFSET:.*]] = llvm.mul %[[ROUNDED_WARP_ID_X]], %[[CST_8]] : i32
// CHECK: %[[LANE_ID_X:.*]] = llvm.udiv %[[LANE_ID]], %[[CST_16]] : i32
// CHECK: %[[LANE_ID_Y:.*]] = llvm.urem %[[LANE_ID]], %[[CST_16]] : i32
// CHECK: %[[OFFSET_Y:.*]] = llvm.mul %[[LANE_ID_Y]], %[[CST_2]] : i32
// CHECK: %[[OFFSET_x:.*]] = llvm.add %[[LANE_ID_X]], %[[WARP_OFFSET]] : i32
// CHECK: %[[VAL_37:.*]] = llvm.urem %[[CST_0]], %[[CST_1]] : i32
// CHECK: %[[VAL_38:.*]] = llvm.udiv %[[CST_0]], %[[CST_1]] : i32
// CHECK: %[[VAL_39:.*]] = llvm.urem %[[VAL_38]], %[[CST_1]] : i32
// CHECK: %[[VAL_40:.*]] = llvm.urem %[[VAL_39]], %[[CST_1]] : i32
// CHECK: %[[VAL_41:.*]] = llvm.urem %[[VAL_37]], %[[CST_1]] : i32
// CHECK: %[[CTA_OFFSET_X:.*]] = llvm.mul %[[VAL_40]], %[[CST_32]] : i32
// CHECK: %[[CTA_OFFSET_Y:.*]] = llvm.mul %[[VAL_41]], %[[CST_32]] : i32
// CHECK: %[[VAL_44:.*]] = llvm.add %[[OFFSET_x]], %[[CTA_OFFSET_X]] : i32
// CHECK: %[[VAL_45:.*]] = llvm.add %[[OFFSET_Y]], %[[CTA_OFFSET_Y]] : i32
// CHECK: %[[OFFSET_X_0:.*]] = llvm.add %[[VAL_44]], %[[CST_0]] : i32
// CHECK: %[[OFFSET_Y_0:.*]] = llvm.add %[[VAL_45]], %[[CST_0]] : i32
// CHECK: %[[OFFSET_Y_1:.*]] = llvm.add %[[VAL_45]], %[[CST_1]] : i32
// CHECK: %[[OFFSET_X_1:.*]] = llvm.add %[[VAL_44]], %[[CST_1]] : i32
// CHECK: %[[OFFSET_X_2:.*]] = llvm.add %[[VAL_44]], %[[CST_2]] : i32
// CHECK: %[[OFFSET_X_3:.*]] = llvm.add %[[VAL_44]], %[[CST_3]] : i32
// CHECK: %[[OFFSET_X_4:.*]] = llvm.add %[[VAL_44]], %[[CST_4]] : i32
// CHECK: %[[OFFSET_X_5:.*]] = llvm.add %[[VAL_44]], %[[CST_5]] : i32
// CHECK: %[[OFFSET_X_6:.*]] = llvm.add %[[VAL_44]], %[[CST_6]] : i32
// CHECK: %[[OFFSET_X_7:.*]] = llvm.add %[[VAL_44]], %[[CST_7]] : i32
// CHECK: %[[OFFSET_X_8:.*]] = llvm.add %[[VAL_44]], %[[CST_16]] : i32
// CHECK: %[[OFFSET_X_9:.*]] = llvm.add %[[VAL_44]], %[[CST_17]] : i32
// CHECK: %[[OFFSET_X_10:.*]] = llvm.add %[[VAL_44]], %[[CST_18]] : i32
// CHECK: %[[OFFSET_X_11:.*]] = llvm.add %[[VAL_44]], %[[CST_19]] : i32
// CHECK: %[[OFFSET_X_12:.*]] = llvm.add %[[VAL_44]], %[[CST_20]] : i32
// CHECK: %[[OFFSET_X_13:.*]] = llvm.add %[[VAL_44]], %[[CST_21]] : i32
// CHECK: %[[OFFSET_X_14:.*]] = llvm.add %[[VAL_44]], %[[CST_22]] : i32
// CHECK: %[[OFFSET_X_15:.*]] = llvm.add %[[VAL_44]], %[[CST_23]] : i32
// CHECK: %[[WARP_ID:.*]] = llvm.udiv %[[THREAD_ID_32]], %[[CST_16]] : i32
// CHECK: %[[VAL_27:.*]] = llvm.and %[[LANE_ID]], %[[CST_1]] : i32
// CHECK: %[[VAL_28:.*]] = llvm.icmp "eq" %[[VAL_27]], %[[CST_0]] : i32
// CHECK: %[[VAL_29:.*]] = llvm.select %[[VAL_28]], %[[CST_0]], %[[CST_2]] : i1, i32
// CHECK: %[[VAL_30:.*]] = llvm.xor %[[CST_0]], %[[VAL_29]] : i32
// CHECK: %[[VAL_31:.*]] = llvm.and %[[LANE_ID]], %[[CST_2]] : i32
// CHECK: %[[VAL_32:.*]] = llvm.icmp "eq" %[[VAL_31]], %[[CST_0]] : i32
// CHECK: %[[VAL_33:.*]] = llvm.select %[[VAL_32]], %[[CST_0]], %[[CST_4]] : i1, i32
// CHECK: %[[VAL_34:.*]] = llvm.xor %[[VAL_30]], %[[VAL_33]] : i32
// CHECK: %[[VAL_35:.*]] = llvm.and %[[LANE_ID]], %[[CST_4]] : i32
// CHECK: %[[VAL_36:.*]] = llvm.icmp "eq" %[[VAL_35]], %[[CST_0]] : i32
// CHECK: %[[VAL_37:.*]] = llvm.select %[[VAL_36]], %[[CST_0]], %[[CST_8]] : i1, i32
// CHECK: %[[VAL_38:.*]] = llvm.xor %[[VAL_34]], %[[VAL_37]] : i32
// CHECK: %[[VAL_39:.*]] = llvm.and %[[LANE_ID]], %[[CST_8]] : i32
// CHECK: %[[VAL_40:.*]] = llvm.icmp "eq" %[[VAL_39]], %[[CST_0]] : i32
// CHECK: %[[VAL_41:.*]] = llvm.select %[[VAL_40]], %[[CST_0]], %[[CST_16]] : i1, i32
// CHECK: %[[VAL_42:.*]] = llvm.xor %[[VAL_38]], %[[VAL_41]] : i32
// CHECK: %[[VAL_43:.*]] = llvm.and %[[WARP_ID]], %[[CST_2]] : i32
// CHECK: %[[VAL_44:.*]] = llvm.icmp "eq" %[[VAL_43]], %[[CST_0]] : i32
// CHECK: %[[VAL_45:.*]] = llvm.select %[[VAL_44]], %[[CST_0]], %[[CST_8]] : i1, i32
// CHECK: %[[VAL_46:.*]] = llvm.xor %[[CST_0]], %[[VAL_45]] : i32
// CHECK: %[[OFFSET_X_0:.*]] = llvm.xor %[[VAL_46]], %[[CST_0]] : i32
// CHECK: %[[OFFSET_Y_0:.*]] = llvm.xor %[[VAL_42]], %[[CST_0]] : i32
// CHECK: %[[OFFSET_Y_1:.*]] = llvm.xor %[[VAL_42]], %[[CST_1]] : i32
// CHECK: %[[OFFSET_X_1:.*]] = llvm.xor %[[VAL_46]], %[[CST_1]] : i32
// CHECK: %[[OFFSET_X_2:.*]] = llvm.xor %[[VAL_46]], %[[CST_2]] : i32
// CHECK: %[[OFFSET_X_3:.*]] = llvm.xor %[[VAL_46]], %[[CST_3]] : i32
// CHECK: %[[OFFSET_X_4:.*]] = llvm.xor %[[VAL_46]], %[[CST_4]] : i32
// CHECK: %[[OFFSET_X_5:.*]] = llvm.xor %[[VAL_46]], %[[CST_5]] : i32
// CHECK: %[[OFFSET_X_6:.*]] = llvm.xor %[[VAL_46]], %[[CST_6]] : i32
// CHECK: %[[OFFSET_X_7:.*]] = llvm.xor %[[VAL_46]], %[[CST_7]] : i32
// CHECK: %[[OFFSET_X_8:.*]] = llvm.xor %[[VAL_46]], %[[CST_16]] : i32
// CHECK: %[[OFFSET_X_9:.*]] = llvm.xor %[[VAL_46]], %[[CST_17]] : i32
// CHECK: %[[OFFSET_X_10:.*]] = llvm.xor %[[VAL_46]], %[[CST_18]] : i32
// CHECK: %[[OFFSET_X_11:.*]] = llvm.xor %[[VAL_46]], %[[CST_19]] : i32
// CHECK: %[[OFFSET_X_12:.*]] = llvm.xor %[[VAL_46]], %[[CST_20]] : i32
// CHECK: %[[OFFSET_X_13:.*]] = llvm.xor %[[VAL_46]], %[[CST_21]] : i32
// CHECK: %[[OFFSET_X_14:.*]] = llvm.xor %[[VAL_46]], %[[CST_22]] : i32
// CHECK: %[[OFFSET_X_15:.*]] = llvm.xor %[[VAL_46]], %[[CST_23]] : i32
// CHECK: llvm.call @_Z18__spirv_ocl_printf({{.*}}, {{.*}}, {{.*}}, {{.*}}, %[[OFFSET_X_0]], %[[OFFSET_Y_0]], {{.*}}, {{.*}})
// CHECK: llvm.call @_Z18__spirv_ocl_printf({{.*}}, {{.*}}, {{.*}}, {{.*}}, %[[OFFSET_X_0]], %[[OFFSET_Y_1]], {{.*}}, {{.*}})
// CHECK: llvm.call @_Z18__spirv_ocl_printf({{.*}}, {{.*}}, {{.*}}, {{.*}}, %[[OFFSET_X_1]], %[[OFFSET_Y_0]], {{.*}}, {{.*}})
Expand Down Expand Up @@ -172,14 +174,13 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 :
// CHECK-DAG: %[[CST_0:.*]] = llvm.mlir.constant(0 : i32) : i32
// CHECK-DAG: %[[CST_1:.*]] = llvm.mlir.constant(1 : i32) : i32
// CHECK-DAG: %[[CST_2:.*]] = llvm.mlir.constant(2 : i32) : i32
// CHECK-DAG: %[[CST_4:.*]] = llvm.mlir.constant(4 : i32) : i32
// CHECK-DAG: %[[CST_8:.*]] = llvm.mlir.constant(8 : i32) : i32
// CHECK-DAG: %[[CST_16:.*]] = llvm.mlir.constant(16 : i32) : i32
// CHECK-DAG: %[[CST_32:.*]] = llvm.mlir.constant(32 : i32) : i32
// CHECK-DAG: %[[CST_3:.*]] = llvm.mlir.constant(3 : i32) : i32
// CHECK-DAG: %[[CST_4:.*]] = llvm.mlir.constant(4 : i32) : i32
// CHECK-DAG: %[[CST_5:.*]] = llvm.mlir.constant(5 : i32) : i32
// CHECK-DAG: %[[CST_6:.*]] = llvm.mlir.constant(6 : i32) : i32
// CHECK-DAG: %[[CST_7:.*]] = llvm.mlir.constant(7 : i32) : i32
// CHECK-DAG: %[[CST_8:.*]] = llvm.mlir.constant(8 : i32) : i32
// CHECK-DAG: %[[CST_16:.*]] = llvm.mlir.constant(16 : i32) : i32
// CHECK-DAG: %[[CST_17:.*]] = llvm.mlir.constant(17 : i32) : i32
// CHECK-DAG: %[[CST_18:.*]] = llvm.mlir.constant(18 : i32) : i32
// CHECK-DAG: %[[CST_19:.*]] = llvm.mlir.constant(19 : i32) : i32
Expand All @@ -190,34 +191,26 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 :
// CHECK: %[[THREADS_ID:.*]] = llvm.call spir_funccc @_Z12get_local_idj(%[[CST_0]])
// CHECK: %[[THREADS_ID_32:.*]] = llvm.trunc %[[THREADS_ID]] : i64 to i32
// CHECK: %[[WARP_ID:.*]] = llvm.udiv %[[THREADS_ID_32]], %[[CST_16]] : i32
// CHECK: %[[LANE_ID:.*]] = llvm.urem %[[THREADS_ID_32]], %[[CST_16]] : i32
// CHECK: %[[VAL_29:.*]] = llvm.udiv %[[WARP_ID]], %[[CST_2]] : i32
// CHECK: %[[WARP_ID_X:.*]] = llvm.urem %[[VAL_29]], %[[CST_2]] : i32
// CHECK: %[[ROUNDED_WARP_ID_X:.*]] = llvm.urem %[[WARP_ID_X]], %[[CST_4]] : i32
// CHECK: %[[WARP_OFFSET_X:.*]] = llvm.mul %[[ROUNDED_WARP_ID_X]], %[[CST_8]] : i32
// CHECK: %[[LANE_OFFSET_X:.*]] = llvm.udiv %[[LANE_ID]], %[[CST_16]] : i32
// CHECK: %[[OFFSET_X:.*]] = llvm.add %[[LANE_OFFSET_X]], %[[WARP_OFFSET_X]] : i32
// CHECK: %[[VAL_35:.*]] = llvm.udiv %[[CST_0]], %[[CST_1]] : i32
// CHECK: %[[VAL_36:.*]] = llvm.urem %[[VAL_35]], %[[CST_1]] : i32
// CHECK: %[[VAL_37:.*]] = llvm.urem %[[VAL_36]], %[[CST_1]] : i32
// CHECK: %[[CTA_OFFSET_X:.*]] = llvm.mul %[[VAL_37]], %[[CST_32]] : i32
// CHECK: %[[VAL_39:.*]] = llvm.add %[[OFFSET_X]], %[[CTA_OFFSET_X]] : i32
// CHECK: %[[OFFSET_X_0:.*]] = llvm.add %[[VAL_39]], %[[CST_0]] : i32
// CHECK: %[[OFFSET_X_1:.*]] = llvm.add %[[VAL_39]], %[[CST_1]] : i32
// CHECK: %[[OFFSET_X_2:.*]] = llvm.add %[[VAL_39]], %[[CST_2]] : i32
// CHECK: %[[OFFSET_X_3:.*]] = llvm.add %[[VAL_39]], %[[CST_3]] : i32
// CHECK: %[[OFFSET_X_4:.*]] = llvm.add %[[VAL_39]], %[[CST_4]] : i32
// CHECK: %[[OFFSET_X_5:.*]] = llvm.add %[[VAL_39]], %[[CST_5]] : i32
// CHECK: %[[OFFSET_X_6:.*]] = llvm.add %[[VAL_39]], %[[CST_6]] : i32
// CHECK: %[[OFFSET_X_7:.*]] = llvm.add %[[VAL_39]], %[[CST_7]] : i32
// CHECK: %[[OFFSET_X_8:.*]] = llvm.add %[[VAL_39]], %[[CST_16]] : i32
// CHECK: %[[OFFSET_X_9:.*]] = llvm.add %[[VAL_39]], %[[CST_17]] : i32
// CHECK: %[[OFFSET_X_10:.*]] = llvm.add %[[VAL_39]], %[[CST_18]] : i32
// CHECK: %[[OFFSET_X_11:.*]] = llvm.add %[[VAL_39]], %[[CST_19]] : i32
// CHECK: %[[OFFSET_X_12:.*]] = llvm.add %[[VAL_39]], %[[CST_20]] : i32
// CHECK: %[[OFFSET_X_13:.*]] = llvm.add %[[VAL_39]], %[[CST_21]] : i32
// CHECK: %[[OFFSET_X_14:.*]] = llvm.add %[[VAL_39]], %[[CST_22]] : i32
// CHECK: %[[OFFSET_X_15:.*]] = llvm.add %[[VAL_39]], %[[CST_23]] : i32
// CHECK: %[[VAL_26:.*]] = llvm.and %[[WARP_ID]], %[[CST_2]] : i32
// CHECK: %[[VAL_27:.*]] = llvm.icmp "eq" %[[VAL_26]], %[[CST_0]] : i32
// CHECK: %[[VAL_28:.*]] = llvm.select %[[VAL_27]], %[[CST_0]], %[[CST_8]] : i1, i32
// CHECK: %[[VAL_29:.*]] = llvm.xor %[[CST_0]], %[[VAL_28]] : i32
// CHECK: %[[OFFSET_X_0:.*]] = llvm.xor %[[VAL_29]], %[[CST_0]] : i32
// CHECK: %[[OFFSET_X_1:.*]] = llvm.xor %[[VAL_29]], %[[CST_1]] : i32
// CHECK: %[[OFFSET_X_2:.*]] = llvm.xor %[[VAL_29]], %[[CST_2]] : i32
// CHECK: %[[OFFSET_X_3:.*]] = llvm.xor %[[VAL_29]], %[[CST_3]] : i32
// CHECK: %[[OFFSET_X_4:.*]] = llvm.xor %[[VAL_29]], %[[CST_4]] : i32
// CHECK: %[[OFFSET_X_5:.*]] = llvm.xor %[[VAL_29]], %[[CST_5]] : i32
// CHECK: %[[OFFSET_X_6:.*]] = llvm.xor %[[VAL_29]], %[[CST_6]] : i32
// CHECK: %[[OFFSET_X_7:.*]] = llvm.xor %[[VAL_29]], %[[CST_7]] : i32
// CHECK: %[[OFFSET_X_8:.*]] = llvm.xor %[[VAL_29]], %[[CST_16]] : i32
// CHECK: %[[OFFSET_X_9:.*]] = llvm.xor %[[VAL_29]], %[[CST_17]] : i32
// CHECK: %[[OFFSET_X_10:.*]] = llvm.xor %[[VAL_29]], %[[CST_18]] : i32
// CHECK: %[[OFFSET_X_11:.*]] = llvm.xor %[[VAL_29]], %[[CST_19]] : i32
// CHECK: %[[OFFSET_X_12:.*]] = llvm.xor %[[VAL_29]], %[[CST_20]] : i32
// CHECK: %[[OFFSET_X_13:.*]] = llvm.xor %[[VAL_29]], %[[CST_21]] : i32
// CHECK: %[[OFFSET_X_14:.*]] = llvm.xor %[[VAL_29]], %[[CST_22]] : i32
// CHECK: %[[OFFSET_X_15:.*]] = llvm.xor %[[VAL_29]], %[[CST_23]] : i32
// CHECK: %[[VAL_56:.*]] = llvm.call @_Z18__spirv_ocl_printf({{.*}}, {{.*}}, {{.*}}, {{.*}}, %[[OFFSET_X_0]], {{.*}}, {{.*}})
// CHECK: %[[VAL_57:.*]] = llvm.call @_Z18__spirv_ocl_printf({{.*}}, {{.*}}, {{.*}}, {{.*}}, %[[OFFSET_X_1]], {{.*}}, {{.*}})
// CHECK: %[[VAL_58:.*]] = llvm.call @_Z18__spirv_ocl_printf({{.*}}, {{.*}}, {{.*}}, {{.*}}, %[[OFFSET_X_2]], {{.*}}, {{.*}})
Expand Down
Loading