Skip to content

Commit 8f6e9d2

Browse files
authored
[NFC] Remove dead code related to IndexCastOp (triton-lang#5596)
IndexCastOp shouldn't exist at the TTIR or TTGIR level.
1 parent 6aa2df9 commit 8f6e9d2

File tree

9 files changed

+37
-86
lines changed

9 files changed

+37
-86
lines changed

lib/Analysis/AxisInfo.cpp

Lines changed: 0 additions & 1 deletion
Original file line number · Diff line number · Diff line change
@@ -1010,7 +1010,6 @@ AxisInfoAnalysis::AxisInfoAnalysis(DataFlowSolver &solver)
10101010
visitors.append<CastOpAxisInfoVisitor<arith::ExtSIOp>,
10111011
CastOpAxisInfoVisitor<arith::ExtUIOp>,
10121012
CastOpAxisInfoVisitor<arith::TruncIOp>,
1013-
CastOpAxisInfoVisitor<arith::IndexCastOp>,
10141013
CastOpAxisInfoVisitor<triton::gpu::ConvertLayoutOp>,
10151014
CastOpAxisInfoVisitor<mlir::UnrealizedConversionCastOp>,
10161015
CastOpAxisInfoVisitor<triton::BitcastOp>>();

lib/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVM.cpp

Lines changed: 0 additions & 30 deletions
Original file line number · Diff line number · Diff line change
@@ -458,35 +458,6 @@ struct AbsFOpConversion
458458
return {rewriter.create<LLVM::FAbsOp>(loc, elemTy, operands[0][0])};
459459
}
460460
};
461-
/// The lowering of index_cast becomes an integer conversion since index
462-
/// becomes an integer. If the bit width of the source and target integer
463-
/// types is the same, just erase the cast. If the target type is wider,
464-
/// sign-extend the value, otherwise truncate it.
465-
struct IndexCastOpLowering
466-
: public ElementwiseOpConversionBase<arith::IndexCastOp,
467-
IndexCastOpLowering> {
468-
using Base =
469-
ElementwiseOpConversionBase<arith::IndexCastOp, IndexCastOpLowering>;
470-
using Base::Base;
471-
using Adaptor = typename Base::OpAdaptor;
472-
473-
SmallVector<Value> createDestOps(arith::IndexCastOp op, OpAdaptor adaptor,
474-
ConversionPatternRewriter &rewriter,
475-
Type elemTy, MultipleOperandsRange operands,
476-
Location loc) const {
477-
auto inElemTy =
478-
this->getTypeConverter()->convertType(getElementType(op.getIn()));
479-
unsigned targetBits = elemTy.getIntOrFloatBitWidth();
480-
unsigned sourceBits = inElemTy.getIntOrFloatBitWidth();
481-
482-
if (targetBits == sourceBits)
483-
return {operands[0][0]};
484-
if (targetBits < sourceBits)
485-
return {
486-
rewriter.create<LLVM::TruncOp>(op.getLoc(), elemTy, operands[0][0])};
487-
return {rewriter.create<LLVM::SExtOp>(op.getLoc(), elemTy, operands[0][0])};
488-
}
489-
};
490461

491462
struct SelectOpConversion
492463
: ElementwiseOpConversionBase<arith::SelectOp, SelectOpConversion> {
@@ -705,6 +676,5 @@ void mlir::triton::populateElementwiseOpToLLVMPatterns(
705676
patterns.add<ElementwiseInlineAsmOpConversion>(typeConverter, benefit);
706677
patterns.add<AbsIOpConversion>(typeConverter, axisInfoAnalysis, benefit);
707678
patterns.add<AbsFOpConversion>(typeConverter, axisInfoAnalysis, benefit);
708-
patterns.add<IndexCastOpLowering>(typeConverter, axisInfoAnalysis, benefit);
709679
patterns.add<SelectOpConversion>(typeConverter, axisInfoAnalysis, benefit);
710680
}

python/src/ir.cc

Lines changed: 1 addition & 11 deletions
Original file line number · Diff line number · Diff line change
@@ -1,4 +1,4 @@
1-
#include <optional>
1+
#include <optional>
22
#include <pybind11/functional.h>
33
#include <pybind11/pybind11.h>
44
#include <pybind11/stl.h>
@@ -1033,16 +1033,6 @@ void init_triton_ir(py::module &&m) {
10331033
else
10341034
return self.create<arith::ExtUIOp>(dstType, src);
10351035
})
1036-
.def("create_to_index",
1037-
[](TritonOpBuilder &self, Value &input) -> Value {
1038-
return self.create<arith::IndexCastOp>(
1039-
self.getBuilder().getIndexType(), input);
1040-
})
1041-
.def("create_index_to_si",
1042-
[](TritonOpBuilder &self, Value &input) -> Value {
1043-
return self.create<arith::IndexCastOp>(
1044-
self.getBuilder().getI64Type(), input);
1045-
})
10461036
.def("create_fmul",
10471037
[](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value {
10481038
return self.create<arith::MulFOp>(lhs, rhs);

test/Analysis/test-alignment.mlir

Lines changed: 10 additions & 8 deletions
Original file line number · Diff line number · Diff line change
@@ -473,14 +473,14 @@ tt.func @for() {
473473
// CHECK-NEXT: contiguity = [1, 1], divisibility = [4, 4], constancy = [128, 32], constant_value = 4
474474
%c_init = arith.constant dense<4> : tensor<128x32xi32>
475475
// CHECK-NEXT: contiguity = [1], divisibility = [128], constancy = [1], constant_value = 128
476-
%ub = arith.constant 128 : index
476+
%ub = arith.constant 128 : i32
477477
// CHECK-NEXT: contiguity = [1], divisibility = [4611686018427387904], constancy = [1], constant_value = 0
478-
%lb = arith.constant 0 : index
478+
%lb = arith.constant 0 : i32
479479
// CHECK-NEXT: contiguity = [1], divisibility = [16], constancy = [1], constant_value = 16
480-
%step = arith.constant 16 : index
481-
%a, %b, %c = scf.for %iv = %lb to %ub step %step iter_args(%a = %a_init, %b = %b_init, %c = %c_init) -> (tensor<128x32xi32>, tensor<128x32xi32>, tensor<128x32xi32>) {
480+
%step = arith.constant 16 : i32
481+
%a, %b, %c = scf.for %iv = %lb to %ub step %step iter_args(%a = %a_init, %b = %b_init, %c = %c_init) -> (tensor<128x32xi32>, tensor<128x32xi32>, tensor<128x32xi32>) : i32 {
482482
// CHECK-NEXT: contiguity = [1], divisibility = [16], constancy = [1], constant_value = <none>
483-
%t = arith.index_cast %iv : index to i32
483+
%t = arith.addi %iv, %lb : i32
484484
// CHECK: contiguity = [1, 1], divisibility = [1, 1], constancy = [128, 32], constant_value = <none>
485485
// CHECK: contiguity = [1, 1], divisibility = [1, 1], constancy = [128, 32], constant_value = <none>
486486
// CHECK: contiguity = [1, 1], divisibility = [4, 4], constancy = [128, 32], constant_value = 4
@@ -492,10 +492,12 @@ tt.func @for() {
492492
// -----
493493

494494
// CHECK-LABEL: @for_dynamic
495-
tt.func @for_dynamic(%lb: index {tt.divisibility = 16 : i32}, %step: index {tt.divisibility = 8 : i32}, %ub: index) {
496-
scf.for %iv = %lb to %ub step %step {
495+
tt.func @for_dynamic(%lb: i32 {tt.divisibility = 16 : i32}, %step: i32 {tt.divisibility = 8 : i32}, %ub: i32) {
496+
// CHECK-NEXT: contiguity = [1], divisibility = [4611686018427387904], constancy = [1], constant_value = 0
497+
%c0 = arith.constant 0 : i32
498+
scf.for %iv = %lb to %ub step %step : i32 {
497499
// CHECK-NEXT: contiguity = [1], divisibility = [8], constancy = [1], constant_value = <none>
498-
%t = arith.index_cast %iv : index to i32
500+
%t = arith.addi %iv, %c0 : i32
499501
}
500502
tt.return
501503
}

test/Triton/vecadd.mlir

Lines changed: 1 addition & 4 deletions
Original file line number · Diff line number · Diff line change
@@ -18,10 +18,7 @@ module {
1818
%11 = tt.splat %cst : f32 -> tensor<256xf32>
1919
%c0_i32 = arith.constant 0 : i32
2020
%c32_i32 = arith.constant 32 : i32
21-
%12 = arith.index_cast %c0_i32 : i32 to index
22-
%13 = arith.index_cast %arg4 : i32 to index
23-
%14 = arith.index_cast %c32_i32 : i32 to index
24-
%15:3 = scf.for %arg6 = %12 to %13 step %14 iter_args(%arg7 = %11, %arg8 = %8, %arg9 = %10) -> (tensor<256xf32>, tensor<256x!tt.ptr<f32>>, tensor<256x!tt.ptr<f32>>) {
21+
%15:3 = scf.for %arg6 = %c0_i32 to %arg4 step %c32_i32 iter_args(%arg7 = %11, %arg8 = %8, %arg9 = %10) -> (tensor<256xf32>, tensor<256x!tt.ptr<f32>>, tensor<256x!tt.ptr<f32>>) : i32 {
2522
%cst_0 = arith.constant 0.000000e+00 : f32
2623
%18 = tt.splat %cst_0 : f32 -> tensor<256xf32>
2724
%19 = tt.load %arg8, %6, %18 : tensor<256x!tt.ptr<f32>>

test/TritonGPU/amd/amd-convert-buffer-ops.mlir

Lines changed: 3 additions & 5 deletions
Original file line number · Diff line number · Diff line change
@@ -344,23 +344,21 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} {
344344
#blocked = #ttg.blocked<{sizePerThread = [2], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>
345345
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} {
346346
// CHECK-LABEL: unsigned_ops
347-
tt.func @unsigned_ops(%arg0: !tt.ptr<bf16> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}, %arg1: !tt.ptr<bf16> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}, %arg2 : i32, %arg3 : i32, %arg4 : f32, %arg5 : index) {
347+
tt.func @unsigned_ops(%arg0: !tt.ptr<bf16> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}, %arg1: !tt.ptr<bf16> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}, %arg2 : i32, %arg3 : i32, %arg4 : f32) {
348348
%c5_i32 = arith.constant 5 : i32
349349
%0 = arith.ceildivui %arg2, %c5_i32 : i32
350350
%1 = arith.divui %arg3, %c5_i32 : i32
351351
%2 = arith.fptoui %arg4 : f32 to i32
352-
%3 = arith.index_castui %arg5 : index to i32
353352
%4 = arith.maxui %arg2, %arg3 : i32
354353
%5 = arith.minui %arg2, %arg3 : i32
355354
%6 = arith.remui %arg2, %c5_i32 : i32
356355
%7 = arith.shrui %arg3, %c5_i32 : i32
357356
%8 = arith.addi %0, %1 : i32
358-
%9 = arith.addi %2, %3 : i32
359357
%10 = arith.addi %4, %5 : i32
360358
%11 = arith.addi %6, %7 : i32
361-
%12 = arith.addi %8, %9 : i32
359+
%12 = arith.addi %8, %2 : i32
362360
%13 = arith.addi %10, %11 : i32
363-
%14 = arith.addi %12, %13 : i32
361+
%14 = arith.addi %8, %13 : i32
364362
%15 = tt.splat %14 : i32 -> tensor<8xi32, #blocked>
365363
%16 = tt.make_range {end = 8 : i32, start = 0 : i32} : tensor<8xi32, #blocked>
366364
%17 = arith.addi %15, %16 : tensor<8xi32, #blocked>

test/TritonGPU/combine.mlir

Lines changed: 17 additions & 21 deletions
Original file line number · Diff line number · Diff line change
@@ -453,9 +453,9 @@ tt.func @select(%arg0: !tt.ptr<f64> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr
453453
// CHECK-NOT: ttg.convert_layout
454454
%cst = arith.constant dense<30000> : tensor<1x1xi32, #blocked2>
455455
%cst_0 = arith.constant dense<30000> : tensor<1x512xi32, #blocked2>
456-
%c512 = arith.constant 512 : index
457-
%c30000 = arith.constant 30000 : index
458-
%c0 = arith.constant 0 : index
456+
%c512 = arith.constant 512 : i32
457+
%c30000 = arith.constant 30000 : i32
458+
%c0 = arith.constant 0 : i32
459459
%cst_1 = arith.constant dense<2048> : tensor<1x1xi32, #blocked2>
460460
%cst_2 = arith.constant dense<0.000000e+00> : tensor<1x512xf64, #blocked2>
461461
%0 = tt.get_program_id x : i32
@@ -473,9 +473,8 @@ tt.func @select(%arg0: !tt.ptr<f64> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr
473473
%12 = tt.broadcast %11 : tensor<1x1xi32, #blocked2> -> tensor<1x512xi32, #blocked2>
474474
%13 = tt.splat %arg0 : !tt.ptr<f64> -> tensor<1x512x!tt.ptr<f64>, #blocked2>
475475
%14 = tt.broadcast %7 : tensor<1x1xi1, #blocked2> -> tensor<1x512xi1, #blocked2>
476-
%15 = scf.for %arg3 = %c0 to %c30000 step %c512 iter_args(%arg4 = %cst_2) -> (tensor<1x512xf64, #blocked2>) {
477-
%16 = arith.index_cast %arg3 : index to i32
478-
%17 = tt.splat %16 : i32 -> tensor<1x512xi32, #blocked2>
476+
%15 = scf.for %arg3 = %c0 to %c30000 step %c512 iter_args(%arg4 = %cst_2) -> (tensor<1x512xf64, #blocked2>) : i32 {
477+
%17 = tt.splat %arg3 : i32 -> tensor<1x512xi32, #blocked2>
479478
%18 = arith.addi %17, %10 : tensor<1x512xi32, #blocked2>
480479
%19 = arith.cmpi "slt", %18, %cst_0 : tensor<1x512xi32, #blocked2>
481480
%20 = arith.addi %18, %12 : tensor<1x512xi32, #blocked2>
@@ -999,9 +998,9 @@ tt.func public @mnist(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg1: !
999998
// CHECK-LABEL: cmp
1000999
module attributes {"ttg.num-warps" = 8 : i32, "ttg.num-ctas" = 1 : i32} {
10011000
tt.func public @cmp(%arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg3: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg4: i32 {tt.divisibility = 16 : i32}, %arg5: i32 {tt.divisibility = 16 : i32}) {
1002-
%c64 = arith.constant 64 : index
1003-
%c2048 = arith.constant 2048 : index
1004-
%c0 = arith.constant 0 : index
1001+
%c64 = arith.constant 64 : i32
1002+
%c2048 = arith.constant 2048 : i32
1003+
%c0 = arith.constant 0 : i32
10051004
%c64_i32 = arith.constant 64 : i32
10061005
%cst = arith.constant dense<-3.40282347E+38> : tensor<64x64xf32, #blocked2>
10071006
%cst_0 = arith.constant dense<4194304> : tensor<64x1xi32, #blocked2>
@@ -1036,9 +1035,8 @@ tt.func public @cmp(%arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg1: !tt
10361035
%22 = arith.muli %21, %cst_0 : tensor<64x1xi32, #blocked2>
10371036
%23 = tt.broadcast %22 : tensor<64x1xi32, #blocked2> -> tensor<64x64xi32, #blocked2>
10381037
%24 = tt.splat %arg1 : !tt.ptr<f32> -> tensor<64x64x!tt.ptr<f32>, #blocked2>
1039-
%25 = scf.for %arg6 = %c0 to %c2048 step %c64 iter_args(%arg7 = %14) -> (tensor<64x64xf32, #blocked2>) {
1040-
%44 = arith.index_cast %arg6 : index to i32
1041-
%45 = tt.splat %44 : i32 -> tensor<1x64xi32, #blocked3>
1038+
%25 = scf.for %arg6 = %c0 to %c2048 step %c64 iter_args(%arg7 = %14) -> (tensor<64x64xf32, #blocked2>) : i32 {
1039+
%45 = tt.splat %arg6 : i32 -> tensor<1x64xi32, #blocked3>
10421040
%46 = arith.addi %45, %10 : tensor<1x64xi32, #blocked3>
10431041
%47 = arith.cmpi "slt", %46, %cst_2 : tensor<1x64xi32, #blocked3>
10441042
%48 = tt.broadcast %46 : tensor<1x64xi32, #blocked3> -> tensor<64x64xi32, #blocked3>
@@ -1092,9 +1090,8 @@ tt.func public @cmp(%arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg1: !tt
10921090
%41 = tt.broadcast %30 : tensor<64x1xf32, #blocked2> -> tensor<64x64xf32, #blocked2>
10931091
%42 = tt.splat %arg2 : !tt.ptr<f32> -> tensor<64x64x!tt.ptr<f32>, #blocked2>
10941092
%43 = tt.splat %arg3 : !tt.ptr<f16> -> tensor<64x64x!tt.ptr<f16>, #blocked2>
1095-
scf.for %arg6 = %c0 to %c2048 step %c64 {
1096-
%44 = arith.index_cast %arg6 : index to i32
1097-
%45 = tt.splat %44 : i32 -> tensor<1x64xi32, #blocked3>
1093+
scf.for %arg6 = %c0 to %c2048 step %c64 : i32 {
1094+
%45 = tt.splat %arg6 : i32 -> tensor<1x64xi32, #blocked3>
10981095
%46 = arith.addi %45, %10 : tensor<1x64xi32, #blocked3>
10991096
%47 = arith.cmpi "slt", %46, %cst_2 : tensor<1x64xi32, #blocked3>
11001097
%48 = tt.broadcast %46 : tensor<1x64xi32, #blocked3> -> tensor<64x64xi32, #blocked3>
@@ -1226,9 +1223,9 @@ module attributes {"ttg.num-warps" = 2 : i32, "ttg.num-ctas" = 1 : i32} {
12261223
module attributes {"ttg.num-warps" = 4 : i32, "ttg.num-ctas" = 1 : i32} {
12271224
tt.func public @reduce_cvt2(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg3: i32 {tt.divisibility = 16 : i32}, %arg4: i32 {tt.divisibility = 16 : i32}) {
12281225
%cst = arith.constant dense<0.000000e+00> : tensor<1x256xf32, #blocked>
1229-
%c3136_i32 = arith.constant 3136 : index
1230-
%c256_i32 = arith.constant 256 : index
1231-
%c0_i32 = arith.constant 0 : index
1226+
%c3136_i32 = arith.constant 3136 : i32
1227+
%c256_i32 = arith.constant 256 : i32
1228+
%c0_i32 = arith.constant 0 : i32
12321229
%cst_0 = arith.constant dense<3.136000e+03> : tensor<1x1xf32, #blocked>
12331230
%cst_1 = arith.constant dense<50176> : tensor<1x256xi32, #blocked>
12341231
%cst_2 = arith.constant dense<196> : tensor<1x1xi32, #blocked>
@@ -1250,9 +1247,8 @@ module attributes {"ttg.num-warps" = 4 : i32, "ttg.num-ctas" = 1 : i32} {
12501247
%12 = tt.broadcast %11 : tensor<1x1xi32, #blocked> -> tensor<1x256xi32, #blocked>
12511248
%13 = tt.splat %arg1 : !tt.ptr<f32> -> tensor<1x256x!tt.ptr<f32>, #blocked>
12521249
%14 = tt.broadcast %7 : tensor<1x1xi1, #blocked> -> tensor<1x256xi1, #blocked>
1253-
%15 = scf.for %arg5 = %c0_i32 to %c3136_i32 step %c256_i32 iter_args(%arg6 = %cst) -> (tensor<1x256xf32, #blocked>) {
1254-
%42 = arith.index_cast %arg5 : index to i32
1255-
%43 = tt.splat %42 : i32 -> tensor<1x256xi32, #blocked>
1250+
%15 = scf.for %arg5 = %c0_i32 to %c3136_i32 step %c256_i32 iter_args(%arg6 = %cst) -> (tensor<1x256xf32, #blocked>) : i32 {
1251+
%43 = tt.splat %arg5 : i32 -> tensor<1x256xi32, #blocked>
12561252
%44 = arith.addi %43, %10 : tensor<1x256xi32, #blocked>
12571253
%45 = arith.cmpi "slt", %44, %cst_4 : tensor<1x256xi32, #blocked>
12581254
%46 = arith.remsi %44, %cst_3 : tensor<1x256xi32, #blocked>

test/TritonGPU/matmul.mlir

Lines changed: 3 additions & 4 deletions
Original file line number · Diff line number · Diff line change
@@ -6,8 +6,8 @@
66
module {
77
tt.func @matmul_kernel__Pfp32_Pfp32_Pfp32_i32_i32_i32_i32_i32_i32_i32_i32_i32__12c64_13c64_14c64_15c8(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg3: i32, %arg4: i32, %arg5: i32, %arg6: i32 {tt.divisibility = 16 : i32}, %arg7: i32, %arg8: i32 {tt.divisibility = 16 : i32}, %arg9: i32, %arg10: i32 {tt.divisibility = 16 : i32}, %arg11: i32) {
88
%cst = arith.constant dense<true> : tensor<64x64xi1>
9-
%c64 = arith.constant 64 : index
10-
%c0 = arith.constant 0 : index
9+
%c64 = arith.constant 64 : i32
10+
%c0 = arith.constant 0 : i32
1111
%cst_0 = arith.constant dense<0.000000e+00> : tensor<64x64xf32>
1212
%c64_i32 = arith.constant 64 : i32
1313
%c63_i32 = arith.constant 63 : i32
@@ -58,8 +58,7 @@ tt.func @matmul_kernel__Pfp32_Pfp32_Pfp32_i32_i32_i32_i32_i32_i32_i32_i32_i32__1
5858
%43 = arith.addi %41, %42 : tensor<64x64xi32>
5959
%44 = tt.splat %arg1 : !tt.ptr<f32> -> tensor<64x64x!tt.ptr<f32>>
6060
%45 = tt.addptr %44, %43 : tensor<64x64x!tt.ptr<f32>>, tensor<64x64xi32>
61-
%46 = arith.index_cast %arg5 : i32 to index
62-
%47:3 = scf.for %arg12 = %c0 to %46 step %c64 iter_args(%arg13 = %cst_0, %arg14 = %34, %arg15 = %45) -> (tensor<64x64xf32>, tensor<64x64x!tt.ptr<f32>>, tensor<64x64x!tt.ptr<f32>>) {
61+
%47:3 = scf.for %arg12 = %c0 to %arg5 step %c64 iter_args(%arg13 = %cst_0, %arg14 = %34, %arg15 = %45) -> (tensor<64x64xf32>, tensor<64x64x!tt.ptr<f32>>, tensor<64x64x!tt.ptr<f32>>) : i32 {
6362
%76 = tt.load %arg14, %cst, %cst_0 : tensor<64x64x!tt.ptr<f32>>
6463
%77 = tt.load %arg15, %cst, %cst_0 : tensor<64x64x!tt.ptr<f32>>
6564
%78 = tt.dot %76, %77, %cst_0 : tensor<64x64xf32> * tensor<64x64xf32> -> tensor<64x64xf32>

third_party/amd/lib/TritonAMDGPUTransforms/ConvertToBufferOps.cpp

Lines changed: 2 additions & 2 deletions
Original file line number · Diff line number · Diff line change
@@ -151,8 +151,8 @@ bool verifyNonNegativeExpr(Value expr, const DenseSet<Value> &assumptions) {
151151
.Case<triton::PtrToIntOp, triton::BitcastOp>(
152152
[&](auto) { return false; })
153153
.Case<arith::CeilDivUIOp, arith::DivUIOp, arith::ExtUIOp,
154-
arith::FPToUIOp, arith::IndexCastUIOp, arith::MaxUIOp,
155-
arith::MinUIOp, arith::RemUIOp, arith::ShRUIOp>(
154+
arith::FPToUIOp, arith::MaxUIOp, arith::MinUIOp, arith::RemUIOp,
155+
arith::ShRUIOp>(
156156
// These OPs also return unsigned values.
157157
// TODO: We can also sniff whether a Value is unsigned by looking
158158
// for whether or not it's used as an argument to one of

0 commit comments

Comments (0)