Skip to content

Commit 95def77

Browse files
Reduce GenISA usages (#2283)
Since PoC productization is completed, remove the GenISA lowering. Some static `createGenISA*` functions are kept in case we need variants in the future that are not yet supported by OpenCL C builtins. When that happens, we should report it to the IGC team ASAP. --------- Signed-off-by: Whitney Tsang <[email protected]>
1 parent 97e0152 commit 95def77

File tree

4 files changed

+28
-152
lines changed

4 files changed

+28
-152
lines changed

include/triton/Tools/Sys/GetEnv.hpp

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -36,8 +36,7 @@ inline const std::set<std::string> CACHE_INVALIDATING_ENV_VARS = {
3636
"TRITON_INTEL_ENABLE_FIRST_LOAD_TO_SLM",
3737
"TRITON_INTEL_ENABLE_INSTR_SCHED",
3838
"TRITON_INTEL_ENABLE_POST_PROCESS_LLIR",
39-
"TRITON_INTEL_REDUCE_TRANSPOSE",
40-
"TRITONGEN_FORCE_GENISA"
39+
"TRITON_INTEL_REDUCE_TRANSPOSE"
4140
// clang-format on
4241
};
4342

Lines changed: 18 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -1,29 +1,26 @@
1-
// RUN: TRITON_INTEL_ENABLE_ADDRESS_PAYLOAD_OPT=1 triton-opt -convert-tritongen-to-llvm -split-input-file %s | FileCheck %s --check-prefixes=CHECK,CHECK-COMMON
2-
// RUN: TRITON_INTEL_ENABLE_ADDRESS_PAYLOAD_OPT=1 TRITONGEN_FORCE_GENISA=1 triton-opt -convert-tritongen-to-llvm -split-input-file %s | FileCheck %s --check-prefixes=CHECK-GENISA,CHECK-COMMON
1+
// RUN: TRITON_INTEL_ENABLE_ADDRESS_PAYLOAD_OPT=1 triton-opt -convert-tritongen-to-llvm -split-input-file %s | FileCheck %s
32

43
// CHECK: llvm.func spir_funccc @__builtin_IB_subgroup_block_read_ap_u8_m8k32v1(!llvm.ptr {llvm.nonnull}, i32, i32, i32) -> vector<8xi16> attributes {memory_effects = #llvm.memory_effects<other = none, argMem = read, inaccessibleMem = none>, no_unwind, will_return}
5-
// CHECK-GENISA: llvm.func spir_funccc @llvm.genx.GenISA.LSC2DBlockReadAddrPayload.v8i16.p0i8(!llvm.ptr {llvm.nonnull}, i32, i32, i32, i32, i32, i32, i1, i1, i32) -> vector<8xi16> attributes {memory_effects = #llvm.memory_effects<other = none, argMem = read, inaccessibleMem = none>, no_unwind}
6-
// CHECK-COMMON: llvm.func spir_funccc @__builtin_IB_subgroup_setBlock2DAddressPayloadBlockY(!llvm.ptr {llvm.nonnull}, i32) attributes {memory_effects = #llvm.memory_effects<other = none, argMem = write, inaccessibleMem = none>, no_unwind, will_return}
7-
// CHECK-COMMON: llvm.func spir_funccc @__builtin_IB_subgroup_setBlock2DAddressPayloadBlockX(!llvm.ptr {llvm.nonnull}, i32) attributes {memory_effects = #llvm.memory_effects<other = none, argMem = write, inaccessibleMem = none>, no_unwind, will_return}
8-
// CHECK-COMMON: llvm.func spir_funccc @__builtin_IB_subgroup_createBlock2DAddressPayload(i64, i32, i32, i32, i32, i32, i32, i32, i32) -> !llvm.ptr attributes {memory_effects = #llvm.memory_effects<other = none, argMem = read, inaccessibleMem = none>, no_unwind}
4+
// CHECK: llvm.func spir_funccc @__builtin_IB_subgroup_setBlock2DAddressPayloadBlockY(!llvm.ptr {llvm.nonnull}, i32) attributes {memory_effects = #llvm.memory_effects<other = none, argMem = write, inaccessibleMem = none>, no_unwind, will_return}
5+
// CHECK: llvm.func spir_funccc @__builtin_IB_subgroup_setBlock2DAddressPayloadBlockX(!llvm.ptr {llvm.nonnull}, i32) attributes {memory_effects = #llvm.memory_effects<other = none, argMem = write, inaccessibleMem = none>, no_unwind, will_return}
6+
// CHECK: llvm.func spir_funccc @__builtin_IB_subgroup_createBlock2DAddressPayload(i64, i32, i32, i32, i32, i32, i32, i32, i32) -> !llvm.ptr attributes {memory_effects = #llvm.memory_effects<other = none, argMem = read, inaccessibleMem = none>, no_unwind}
97

108
llvm.func @triton_gen.2Dblockload(%ptr : !llvm.ptr<1>, %base_width : i32, %base_height : i32, %base_pitch : i32, %x : i32, %y : i32) {
11-
// CHECK-COMMON: llvm.func @triton_gen.2Dblockload(%arg0: !llvm.ptr<1>, %arg1: i32, %arg2: i32, %arg3: i32, %arg4: i32, %arg5: i32) {
12-
// CHECK-COMMON-DAG: [[ZERO:%.*]] = llvm.mlir.constant(0 : i32) : i32
13-
// CHECK-COMMON-DAG: [[ONE:%.*]] = llvm.mlir.constant(1 : i32) : i32
14-
// CHECK-COMMON-DAG: [[PTR:%.*]] = llvm.ptrtoint %arg0 : !llvm.ptr<1> to i64
15-
// CHECK-COMMON-DAG: [[WIDTH:%.*]] = llvm.sub %arg1, [[ONE]] : i32
16-
// CHECK-COMMON-DAG: [[HEIGHT:%.*]] = llvm.sub %arg2, [[ONE]] : i32
17-
// CHECK-COMMON-DAG: [[PITCH:%.*]] = llvm.sub %arg3, [[ONE]] : i32
18-
// CHECK-COMMON-DAG: [[C32:%.*]] = llvm.mlir.constant(32 : i32) : i32
19-
// CHECK-COMMON-DAG: [[C8:%.*]] = llvm.mlir.constant(8 : i32) : i32
20-
// CHECK-COMMON-DAG: [[C1:%.*]] = llvm.mlir.constant(1 : i32) : i32
21-
// CHECK-COMMON: [[AP:%.*]] = llvm.call spir_funccc @__builtin_IB_subgroup_createBlock2DAddressPayload([[PTR]], [[WIDTH]], [[HEIGHT]], [[PITCH]], [[ZERO]], [[ZERO]], [[C32]], [[C8]], [[C1]]) {{.*}} : (i64, i32, i32, i32, i32, i32, i32, i32, i32) -> !llvm.ptr
22-
// CHECK-COMMON: llvm.call spir_funccc @__builtin_IB_subgroup_setBlock2DAddressPayloadBlockX([[AP]], %arg4) {{.*}} : (!llvm.ptr, i32) -> ()
23-
// CHECK-COMMON: llvm.call spir_funccc @__builtin_IB_subgroup_setBlock2DAddressPayloadBlockY([[AP]], %arg5) {{.*}} : (!llvm.ptr, i32) -> ()
24-
// CHECK-COMMON: [[ZERO:%.*]] = llvm.mlir.constant(0 : i32) : i32
9+
// CHECK: llvm.func @triton_gen.2Dblockload(%arg0: !llvm.ptr<1>, %arg1: i32, %arg2: i32, %arg3: i32, %arg4: i32, %arg5: i32) {
10+
// CHECK-DAG: [[ZERO:%.*]] = llvm.mlir.constant(0 : i32) : i32
11+
// CHECK-DAG: [[ONE:%.*]] = llvm.mlir.constant(1 : i32) : i32
12+
// CHECK-DAG: [[PTR:%.*]] = llvm.ptrtoint %arg0 : !llvm.ptr<1> to i64
13+
// CHECK-DAG: [[WIDTH:%.*]] = llvm.sub %arg1, [[ONE]] : i32
14+
// CHECK-DAG: [[HEIGHT:%.*]] = llvm.sub %arg2, [[ONE]] : i32
15+
// CHECK-DAG: [[PITCH:%.*]] = llvm.sub %arg3, [[ONE]] : i32
16+
// CHECK-DAG: [[C32:%.*]] = llvm.mlir.constant(32 : i32) : i32
17+
// CHECK-DAG: [[C8:%.*]] = llvm.mlir.constant(8 : i32) : i32
18+
// CHECK-DAG: [[C1:%.*]] = llvm.mlir.constant(1 : i32) : i32
19+
// CHECK: [[AP:%.*]] = llvm.call spir_funccc @__builtin_IB_subgroup_createBlock2DAddressPayload([[PTR]], [[WIDTH]], [[HEIGHT]], [[PITCH]], [[ZERO]], [[ZERO]], [[C32]], [[C8]], [[C1]]) {{.*}} : (i64, i32, i32, i32, i32, i32, i32, i32, i32) -> !llvm.ptr
20+
// CHECK: llvm.call spir_funccc @__builtin_IB_subgroup_setBlock2DAddressPayloadBlockX([[AP]], %arg4) {{.*}} : (!llvm.ptr, i32) -> ()
21+
// CHECK: llvm.call spir_funccc @__builtin_IB_subgroup_setBlock2DAddressPayloadBlockY([[AP]], %arg5) {{.*}} : (!llvm.ptr, i32) -> ()
22+
// CHECK: [[ZERO:%.*]] = llvm.mlir.constant(0 : i32) : i32
2523
// CHECK: llvm.call spir_funccc @__builtin_IB_subgroup_block_read_ap_u8_m8k32v1([[AP]], [[ZERO]], [[ZERO]], [[ZERO]]) {{.*}} : (!llvm.ptr, i32, i32, i32) -> vector<8xi16>
26-
// CHECK-GENISA: llvm.call spir_funccc @llvm.genx.GenISA.LSC2DBlockReadAddrPayload.v8i16.p0i8([[AP]], [[ZERO]], [[ZERO]], {{.*}}) {{.*}} : (!llvm.ptr, i32, i32, i32, i32, i32, i32, i1, i1, i32) -> vector<8xi16>
2724
%0 = triton_gen.2Dblockload %ptr, %base_width, %base_height, %base_pitch, %x, %y {elem_size_in_bits=8, tile_width=32, tile_height=8, v_blocks=1, transpose=false, vnni_transform=false, cache_control=Default} : (!llvm.ptr<1>, i32, i32, i32, i32, i32) -> vector<8xi16>
2825
llvm.return
2926
}
@@ -32,9 +29,6 @@ llvm.func @triton_gen.2Dblockload(%ptr : !llvm.ptr<1>, %base_width : i32, %base_
3229

3330
llvm.func @triton_gen.2Dblockload(%ptr : !llvm.ptr<1>, %base_width : i32, %base_height : i32, %base_pitch : i32, %x : i32, %y : i32) {
3431
// CHECK: llvm.call spir_funccc @__builtin_IB_subgroup_block_read_ap_transpose_u32_m16k8v1
35-
// CHECK-GENISA-DAG: [[TRUE:%.*]] = llvm.mlir.constant(true) : i1
36-
// CHECK-GENISA-DAG: [[FALSE:%.*]] = llvm.mlir.constant(false) : i1
37-
// CHECK-GENISA: llvm.call spir_funccc @llvm.genx.GenISA.LSC2DBlockReadAddrPayload.v8i32.p0i8({{.*}}, [[TRUE]], [[FALSE]], {{.*}})
3832
%0 = triton_gen.2Dblockload %ptr, %base_width, %base_height, %base_pitch, %x, %y {elem_size_in_bits=32, tile_width=8, tile_height=16, v_blocks=1, transpose=true, vnni_transform=false, cache_control=Default} : (!llvm.ptr<1>, i32, i32, i32, i32, i32) -> vector<8xi32>
3933
llvm.return
4034
}
@@ -43,9 +37,6 @@ llvm.func @triton_gen.2Dblockload(%ptr : !llvm.ptr<1>, %base_width : i32, %base_
4337

4438
llvm.func @triton_gen.2Dblockload(%ptr : !llvm.ptr<1>, %base_width : i32, %base_height : i32, %base_pitch : i32, %x : i32, %y : i32) {
4539
// CHECK: llvm.call spir_funccc @__builtin_IB_subgroup_block_read_ap_transform_u8_m32k16v1
46-
// CHECK-GENISA-DAG: [[TRUE:%.*]] = llvm.mlir.constant(true) : i1
47-
// CHECK-GENISA-DAG: [[FALSE:%.*]] = llvm.mlir.constant(false) : i1
48-
// CHECK-GENISA: llvm.call spir_funccc @llvm.genx.GenISA.LSC2DBlockReadAddrPayload.v8i32.p0i8({{.*}}, [[FALSE]], [[TRUE]], {{.*}})
4940
%0 = triton_gen.2Dblockload %ptr, %base_width, %base_height, %base_pitch, %x, %y {elem_size_in_bits=8, tile_width=16, tile_height=32, v_blocks=1, transpose=false, vnni_transform=true, cache_control=Default} : (!llvm.ptr<1>, i32, i32, i32, i32, i32) -> vector<8xi32>
5041
llvm.return
5142
}

third_party/intel/lib/Dialect/TritonGEN/IR/TritonGENOps.cpp

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -316,9 +316,6 @@ LogicalResult TritonGEN::Matrix2DBlockLoadOp::verify() {
316316
if (verify2DBlockLoadHWRestriction(*this).failed())
317317
return failure();
318318

319-
if (tools::getBoolEnv("TRITONGEN_FORCE_GENISA"))
320-
return success();
321-
322319
if (verifyMatrixInput(*this).failed())
323320
return failure();
324321

@@ -383,9 +380,6 @@ LogicalResult TritonGEN::Matrix2DBlockStoreOp::verify() {
383380
if (verify2DBlockStoreHWRestriction(*this).failed())
384381
return failure();
385382

386-
if (tools::getBoolEnv("TRITONGEN_FORCE_GENISA"))
387-
return success();
388-
389383
if (verifyMatrixInput(*this).failed())
390384
return failure();
391385

third_party/intel/lib/TritonGENToLLVM/TritonGENToLLVMPass.cpp

Lines changed: 9 additions & 117 deletions
Original file line numberDiff line numberDiff line change
@@ -115,57 +115,14 @@ static LLVM::CallOp createDeviceFunctionCall(
115115
return callOp;
116116
}
117117

118-
static std::string getGenISATypeMangling(Type ty) {
118+
[[maybe_unused]] static std::string getGenISATypeMangling(Type ty) {
119119
if (auto vecTy = dyn_cast<VectorType>(ty))
120120
return "v" + std::to_string(vecTy.getNumElements()) +
121121
getGenISATypeMangling(vecTy.getElementType());
122122
return (ty.isInteger() ? "i" : "f") +
123123
std::to_string(ty.getIntOrFloatBitWidth());
124124
}
125125

126-
static LLVM::CallOp
127-
createGenISASubGroupReduce(TritonGEN::SubGroupReduceOp op, Value val,
128-
ConversionPatternRewriter &rewriter) {
129-
auto getKindVal = [](TritonGEN::ReduceKind kind) -> int {
130-
switch (kind) {
131-
case TritonGEN::ReduceKind::ADD:
132-
return 0;
133-
case TritonGEN::ReduceKind::MUL:
134-
return 1;
135-
case TritonGEN::ReduceKind::MIN:
136-
return 2;
137-
case TritonGEN::ReduceKind::MAX:
138-
return 3;
139-
case TritonGEN::ReduceKind::AND:
140-
return 8;
141-
case TritonGEN::ReduceKind::OR:
142-
return 6;
143-
case TritonGEN::ReduceKind::XOR:
144-
return 7;
145-
}
146-
llvm_unreachable("unsupported reduce kind");
147-
};
148-
149-
Location loc = op.getLoc();
150-
auto kind =
151-
rewriter.create<LLVM::ConstantOp>(loc, i8_ty, getKindVal(op.getKind()));
152-
153-
std::string funcName =
154-
"llvm.genx.GenISA.WaveAll." + getGenISATypeMangling(val.getType());
155-
SmallVector<Type> argTypes = {val.getType(), i8_ty, i32_ty};
156-
SmallVector<Value> args = {val, kind, i32_val(0)};
157-
158-
auto inaccessibleMemOnly = rewriter.getAttr<LLVM::MemoryEffectsAttr>(
159-
/*other=*/LLVM::ModRefInfo::NoModRef,
160-
/*argMem=*/LLVM::ModRefInfo::NoModRef,
161-
/*inaccessibleMem=*/LLVM::ModRefInfo::ModRef);
162-
auto funcAttrs = convergentNoUnwindWillReturnAttrs;
163-
funcAttrs.memEffectsAttr = inaccessibleMemOnly;
164-
165-
return createDeviceFunctionCall(rewriter, funcName, val.getType(), argTypes,
166-
args, {}, funcAttrs);
167-
}
168-
169126
static SmallVector<Attribute>
170127
loadCacheControlToDecoration(Builder &builder, uint32_t operandNum,
171128
TritonGEN::LoadCacheControl orig) {
@@ -249,8 +206,9 @@ static bool isOCLBuiltinAvailable(TritonGEN::Matrix2DBlockLoadOp op) {
249206
return false;
250207
}
251208

252-
static Value createGenISA2DBlockRead(TritonGEN::Matrix2DBlockLoadOp op,
253-
ConversionPatternRewriter &rewriter) {
209+
[[maybe_unused]] static Value
210+
createGenISA2DBlockRead(TritonGEN::Matrix2DBlockLoadOp op,
211+
ConversionPatternRewriter &rewriter) {
254212
MLIRContext *ctx = rewriter.getContext();
255213
VectorType resType = op.getRes().getType();
256214
Location loc = op->getLoc();
@@ -399,55 +357,10 @@ createBlock2DReadWithAddressPayloadUpdate(TritonGEN::Matrix2DBlockLoadOp op,
399357
paramAttrs, funcAttrs);
400358
};
401359

402-
auto createBlock2DReadGenISA = [&](Value ptr,
403-
TritonGEN::Matrix2DBlockLoadOp op) {
404-
assert(isa<LLVM::LLVMPointerType>(ptr.getType()) &&
405-
"Expecting a pointer type");
406-
407-
auto vecType = dyn_cast<VectorType>(resType);
408-
assert(vecType && vecType.getShape().size() == 1 &&
409-
"Expecting a 1D vector");
410-
411-
std::string fnName = "llvm.genx.GenISA.LSC2DBlockReadAddrPayload." +
412-
getGenISATypeMangling(vecType) + ".p0i8";
413-
414-
Value zero = i32_val(0);
415-
SmallVector<Type> argTypes{ptr.getType(), i32_ty, i32_ty, i32_ty, i32_ty,
416-
i32_ty, i32_ty, i1_ty, i1_ty, i32_ty};
417-
SmallVector<Value> args{ptr,
418-
zero, // x
419-
zero, // y
420-
i32_val(op.getElemSizeInBits()),
421-
i32_val(op.getTileWidth()),
422-
i32_val(op.getTileHeight()),
423-
i32_val(op.getVBlocks()),
424-
i1_val(op.getTranspose()),
425-
i1_val(op.getVnniTransform()),
426-
i32_val(4) /*cache*/};
427-
428-
// Function and parameters attributes.
429-
std::array<std::pair<unsigned, mlir::StringRef>, 1> paramAttrs{
430-
std::make_pair(0, LLVM::LLVMDialect::getNonNullAttrName())};
431-
432-
auto memAttr = rewriter.getAttr<LLVM::MemoryEffectsAttr>(
433-
/*other=*/LLVM::ModRefInfo::NoModRef,
434-
/*argMem=*/LLVM::ModRefInfo::Ref,
435-
/*inaccessibleMem=*/LLVM::ModRefInfo::NoModRef);
436-
auto funcAttrs = noUnwindAttrs;
437-
funcAttrs.memEffectsAttr = memAttr;
438-
439-
return createDeviceFunctionCall(rewriter, fnName, resType, argTypes, args,
440-
paramAttrs, funcAttrs);
441-
};
442-
443360
Value ptr = createBlock2DAddressPayload(op);
444361
setBlock2DAddressPayload(ptr, op);
445362

446-
// TODO: Remove GenISA lowering after PoC productization is completed.
447-
char *env = std::getenv("TRITONGEN_FORCE_GENISA");
448-
const bool useGenISA = env ? (bool)std::atoi(env) : false;
449-
return (useGenISA) ? createBlock2DReadGenISA(ptr, op)
450-
: createBlock2DRead(ptr, op);
363+
return createBlock2DRead(ptr, op);
451364
}
452365

453366
static SmallVector<Attribute>
@@ -502,7 +415,7 @@ storeCacheControlToCacheControls(Builder &builder,
502415
return builder.getAttr<TritonGEN::DecorationCacheControlAttr>(decorations);
503416
}
504417

505-
static LLVM::CallOp
418+
[[maybe_unused]] static LLVM::CallOp
506419
createGenISA2DBlockWrite(TritonGEN::Matrix2DBlockStoreOp op,
507420
ConversionPatternRewriter &rewriter) {
508421
MLIRContext *ctx = rewriter.getContext();
@@ -550,7 +463,7 @@ createGenISA2DBlockWrite(TritonGEN::Matrix2DBlockStoreOp op,
550463
return call;
551464
}
552465

553-
static LLVM::CallOp
466+
[[maybe_unused]] static LLVM::CallOp
554467
createGenISA2DBlockPrefetch(TritonGEN::Matrix2DBlockPrefetchOp op,
555468
ConversionPatternRewriter &rewriter) {
556469
MLIRContext *ctx = rewriter.getContext();
@@ -783,13 +696,6 @@ struct TritonSubGroupReduceLowering
783696
SmallVector<Value> args{val};
784697
bool useCluster = (getSubgroupSize(op) != op.getSize());
785698

786-
if (tools::getBoolEnv("TRITONGEN_FORCE_GENISA") && !useCluster) {
787-
Value result = createGenISASubGroupReduce(op, val, rewriter).getResult();
788-
result = TritonSubGroupBase::truncate(op, result, origTy, rewriter);
789-
rewriter.replaceOp(op, result);
790-
return success();
791-
}
792-
793699
std::string fnName = "sub_group_";
794700
fnName += useCluster ? "clustered_" : "non_uniform_";
795701
fnName += "reduce_" + stringifyReduceKind(op.getKind()).str();
@@ -969,9 +875,8 @@ struct TritonMatrix2DBlockLoadLowering
969875
return success();
970876
}
971877

972-
// TODO: Remove GenISA lowering after PoC productization is completed.
973-
if (tools::getBoolEnv("TRITONGEN_FORCE_GENISA") ||
974-
!isOCLBuiltinAvailable(op)) {
878+
if (!isOCLBuiltinAvailable(op)) {
879+
op.emitWarning("OpenCL API not available for this operation");
975880
rewriter.replaceOp(op, createGenISA2DBlockRead(op, rewriter));
976881
return success();
977882
}
@@ -1036,12 +941,6 @@ struct TritonMatrix2DBlockStoreLowering
1036941
LogicalResult
1037942
matchAndRewrite(TritonGEN::Matrix2DBlockStoreOp op, OpAdaptor adaptor,
1038943
ConversionPatternRewriter &rewriter) const override {
1039-
// TODO: Remove GenISA lowering after PoC productization is completed.
1040-
if (tools::getBoolEnv("TRITONGEN_FORCE_GENISA")) {
1041-
rewriter.replaceOp(op, createGenISA2DBlockWrite(op, rewriter));
1042-
return success();
1043-
}
1044-
1045944
MLIRContext *ctx = rewriter.getContext();
1046945
Location loc = op->getLoc();
1047946

@@ -1104,13 +1003,6 @@ struct TritonMatrix2DBlockPrefetchLowering
11041003
LogicalResult
11051004
matchAndRewrite(TritonGEN::Matrix2DBlockPrefetchOp op, OpAdaptor adaptor,
11061005
ConversionPatternRewriter &rewriter) const override {
1107-
// TODO: Remove GenISA lowering after PoC productization is completed.
1108-
bool useGenISA = tools::getBoolEnv("TRITONGEN_FORCE_GENISA");
1109-
if (useGenISA) {
1110-
rewriter.replaceOp(op, createGenISA2DBlockPrefetch(op, rewriter));
1111-
return success();
1112-
}
1113-
11141006
MLIRContext *ctx = rewriter.getContext();
11151007
Location loc = op->getLoc();
11161008
std::string fnName = "intel_sub_group_2d_block_prefetch_";

0 commit comments

Comments
 (0)