Skip to content

Commit 041ec1b

Browse files
authored
[Gluon] Fix splat returning auto encoding (#7490)
Previously we propagated the src encoding to all operands, even those that aren't tensors. This led to errors when using a splat op that returns an auto encoding.
1 parent 1053fca commit 041ec1b

File tree

2 files changed

+26
-3
lines changed

2 files changed

+26
-3
lines changed

lib/Dialect/Gluon/Transforms/ResolveAutoEncodings.cpp

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -121,9 +121,12 @@ LogicalResult inferAutoLayouts(FuncOp func) {
121121
} else {
122122
auto srcEncoding = inferSrcEncoding(definingOp, enc);
123123
if (srcEncoding) {
124-
if (failed(updateEncoding(
125-
llvm::to_vector_of<Value>(definingOp->getOperands()),
126-
srcEncoding)))
124+
llvm::SmallVector<Value> tensorOperands;
125+
for (auto operand : definingOp->getOperands())
126+
if (isa<RankedTensorType>(operand.getType()))
127+
tensorOperands.push_back(operand);
128+
129+
if (failed(updateEncoding(tensorOperands, srcEncoding)))
127130
return failure();
128131
}
129132
}

test/Gluon/auto_encoding.mlir

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -96,3 +96,23 @@ module attributes {"ttg.num-warps" = 4 : i32, "ttg.num-ctas" = 1 : i32} {
9696
tt.return %cvt : tensor<32xi32, #blocked>
9797
}
9898
}
99+
100+
101+
// -----
102+
103+
104+
#blocked = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>
105+
106+
module attributes {"ttg.target" = "cuda:90", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} {
107+
tt.func public @infer_make_range() -> tensor<16xi32, #blocked> {
108+
// CHECK-DAG: [[BLOCKED:#.*]] = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>
109+
// CHECK: [[CST:%.*]] = arith.constant 0 : i32
110+
// CHECK: [[SPLAT:%.*]] = tt.splat [[CST]] : i32 -> tensor<16xi32, [[BLOCKED]]>
111+
// CHECK: [[RES:%.*]] = ttg.convert_layout [[SPLAT]] : tensor<16xi32, [[BLOCKED]]> -> tensor<16xi32, [[BLOCKED]]>
112+
// CHECK: tt.return [[RES]] : tensor<16xi32, [[BLOCKED]]>
113+
%cst = arith.constant 0 : i32
114+
%0 = tt.splat %cst : i32 -> tensor<16xi32, #gluon.auto_encoding>
115+
%cvt = ttg.convert_layout %0 : tensor<16xi32, #gluon.auto_encoding> -> tensor<16xi32, #blocked>
116+
tt.return %cvt : tensor<16xi32, #blocked>
117+
}
118+
}

0 commit comments

Comments
 (0)