Commit fb19a56

Support causal flash attention (#2425)
This PR adds support for causal flash attention (FA):

- Keeps the encoding on row-vector tensor operations, as these must be left untouched when lowering to the SIMT program.
- Extends the pattern-matching helper that determines whether a tensor is transposed so that it looks through advance operations. (The second attention loop uses a transposed tensor pointer that is `tt.advance`'d between the loops.)

---------

Signed-off-by: Julian Oppermann <[email protected]>
1 parent b12d0dd commit fb19a56
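For context, the second bullet boils down to the condensed C++ sketch below, distilled from the MatchTargetSize.cpp hunk further down. The helper name `isTransposedTensorPtr` is hypothetical, and the surrounding pattern matching (walking from dot operands through loop arguments and glue ops to the load pointer) is elided; only the new `tt.advance` case of `getTransposeFlagFromValue` is shown.

```cpp
// Sketch only: assumes the usual MLIR / Triton dialect headers and the
// `tt` namespace alias used throughout MatchTargetSize.cpp.
#include "mlir/IR/Value.h"
#include "triton/Dialect/Triton/IR/Dialect.h"

using namespace mlir;
namespace tt = triton;

// Hypothetical, reduced version of getTransposeFlagFromValue.
static bool isTransposedTensorPtr(Value ptr) {
  // A block pointer is transposed when its second-to-last order entry is not
  // 1, e.g. tt.make_tensor_ptr ... {order = array<i32: 0, 1>}, as for the K
  // operand of the attention kernel.
  if (auto makePtr = ptr.getDefiningOp<tt::MakeTensorPtrOp>()) {
    ArrayRef<int32_t> order = makePtr.getOrder();
    return order[order.size() - 2] != 1;
  }

  // New in this PR: tt.advance only shifts the block offsets and preserves
  // the layout, so look through it. This covers the second attention loop,
  // whose pointer is tt.advance'd between the two loops.
  if (auto advance = ptr.getDefiningOp<tt::AdvanceOp>())
    return isTransposedTensorPtr(advance.getPtr());

  return false;
}
```

The new test case in match-target-size.mlir exercises exactly this situation: the K pointer is created with `order = array<i32: 0, 1>` and advanced between the two loops, and the loads/advances in both loops are expected to keep the transposed 16x16 sub-tile shape.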

File tree: 3 files changed, +134 −56 lines

benchmarks/triton_kernels_benchmark/flash_attention_fwd_benchmark.py

Lines changed: 11 additions & 16 deletions
@@ -269,23 +269,18 @@ def benchmark(Z, H, N_CTX, D_HEAD, CAUSAL, provider):
                                                               quantiles=quantiles)
 
     elif provider == 'triton':
-        # FIXME: remove below if condition when extend attention support for Causal = True done
-        # https://github.com/intel/intel-xpu-backend-for-triton/issues/1102
-        if os.environ.get('TRITON_INTEL_ADVANCED_PATH', '0') == '1' and CAUSAL:
-            min_ms, max_ms, mean, cv = (float('inf'), ) * 4
+        triton_fn = lambda: forward(q, k, v, CAUSAL, sm_scale)
+        if benchmark_suit.USE_IPEX_OPTION:
+            torch_fn = lambda: torch.nn.functional.scaled_dot_product_attention(
+                q, k, v, attn_mask=None, dropout_p=0.0, is_causal=CAUSAL, scale=sm_scale).to(torch.float32)
         else:
-            triton_fn = lambda: forward(q, k, v, CAUSAL, sm_scale)
-            if benchmark_suit.USE_IPEX_OPTION:
-                torch_fn = lambda: torch.nn.functional.scaled_dot_product_attention(
-                    q, k, v, attn_mask=None, dropout_p=0.0, is_causal=CAUSAL, scale=sm_scale).to(torch.float32)
-            else:
-                # FIXME: use torch sdpa for result check after https://github.com/intel/intel-xpu-backend-for-triton/issues/2042 fixed
-                torch_fn = lambda: torch.nn.functional.scaled_dot_product_attention(q.cpu(), k.cpu(), v.cpu(
-                ), attn_mask=None, dropout_p=0.0, is_causal=CAUSAL, scale=sm_scale).to(torch.float32)
-            atol = 1e-1 if N_CTX == 16384 else 1e-2
-            benchmark_suit.assert_close(triton_fn(), torch_fn(), atol=atol, rtol=1e-3, err_msg='triton to torch')
-            _, min_ms, max_ms, mean, cv = benchmark_suit.do_bench(triton_fn, warmup=10, rep=10, quantiles=quantiles,
-                                                                  kernel_name='_attn_fwd')
+            # FIXME: use torch sdpa for result check after https://github.com/intel/intel-xpu-backend-for-triton/issues/2042 fixed
+            torch_fn = lambda: torch.nn.functional.scaled_dot_product_attention(q.cpu(), k.cpu(), v.cpu(
+            ), attn_mask=None, dropout_p=0.0, is_causal=CAUSAL, scale=sm_scale).to(torch.float32)
+        atol = 1e-1 if N_CTX == 16384 else 1e-2
+        benchmark_suit.assert_close(triton_fn(), torch_fn(), atol=atol, rtol=1e-3, err_msg='triton to torch')
+        _, min_ms, max_ms, mean, cv = benchmark_suit.do_bench(triton_fn, warmup=10, rep=10, quantiles=quantiles,
+                                                              kernel_name='_attn_fwd')
 
     elif provider == 'xetla':
         module_name = f'flash_attn_causal_{CAUSAL}'.lower()

test/TritonIntelGPU/match-target-size.mlir

Lines changed: 75 additions & 8 deletions
@@ -537,14 +537,14 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 8 :
     // CHECK: %[[BC1:.*]] = triton_intel_gpu.broadcast %[[ED1]] : tensor<16x1xi32, #warp> -> tensor<16x16xi32>
     %4 = triton_intel_gpu.broadcast %2 : tensor<16x1xi32, #warp> -> tensor<16x64xi32, #warp>
 
-    // CHECK: %[[EX0:.*]] = triton_intel_gpu.extract %[[ED2]][0] : tensor<1x64xi32, #warp> -> tensor<1x16xi32>
-    // CHECK: %[[BC20:.*]] = triton_intel_gpu.broadcast %[[EX0]] : tensor<1x16xi32> -> tensor<16x16xi32>
-    // CHECK: %[[EX1:.*]] = triton_intel_gpu.extract %[[ED2]][1] : tensor<1x64xi32, #warp> -> tensor<1x16xi32>
-    // CHECK: %[[BC21:.*]] = triton_intel_gpu.broadcast %[[EX1]] : tensor<1x16xi32> -> tensor<16x16xi32>
-    // CHECK: %[[EX2:.*]] = triton_intel_gpu.extract %[[ED2]][2] : tensor<1x64xi32, #warp> -> tensor<1x16xi32>
-    // CHECK: %[[BC22:.*]] = triton_intel_gpu.broadcast %[[EX2]] : tensor<1x16xi32> -> tensor<16x16xi32>
-    // CHECK: %[[EX3:.*]] = triton_intel_gpu.extract %[[ED2]][3] : tensor<1x64xi32, #warp> -> tensor<1x16xi32>
-    // CHECK: %[[BC23:.*]] = triton_intel_gpu.broadcast %[[EX3]] : tensor<1x16xi32> -> tensor<16x16xi32>
+    // CHECK: %[[EX0:.*]] = triton_intel_gpu.extract %[[ED2]][0] : tensor<1x64xi32, #warp> -> tensor<1x16xi32, #warp>
+    // CHECK: %[[BC20:.*]] = triton_intel_gpu.broadcast %[[EX0]] : tensor<1x16xi32, #warp> -> tensor<16x16xi32>
+    // CHECK: %[[EX1:.*]] = triton_intel_gpu.extract %[[ED2]][1] : tensor<1x64xi32, #warp> -> tensor<1x16xi32, #warp>
+    // CHECK: %[[BC21:.*]] = triton_intel_gpu.broadcast %[[EX1]] : tensor<1x16xi32, #warp> -> tensor<16x16xi32>
+    // CHECK: %[[EX2:.*]] = triton_intel_gpu.extract %[[ED2]][2] : tensor<1x64xi32, #warp> -> tensor<1x16xi32, #warp>
+    // CHECK: %[[BC22:.*]] = triton_intel_gpu.broadcast %[[EX2]] : tensor<1x16xi32, #warp> -> tensor<16x16xi32>
+    // CHECK: %[[EX3:.*]] = triton_intel_gpu.extract %[[ED2]][3] : tensor<1x64xi32, #warp> -> tensor<1x16xi32, #warp>
+    // CHECK: %[[BC23:.*]] = triton_intel_gpu.broadcast %[[EX3]] : tensor<1x16xi32, #warp> -> tensor<16x16xi32>
     %5 = triton_intel_gpu.broadcast %3 : tensor<1x64xi32, #warp> -> tensor<16x64xi32, #warp>
 
     // CHECK: arith.addi %[[BC1]], %[[BC20]] : tensor<16x16xi32>
@@ -563,3 +563,70 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 8 :
     tt.return
   }
 }
+
+// -----
+
+// COM: This test checks that the tt.load/tt.advance ops in _both_ loops are detected as being transposed and hence having the 16x16 shape (would be 32x16 otherwise).
+
+#warp = #triton_intel_gpu.warp<{sizePerThread = [16, 64], threadsPerWarp = [1, 1], order = [1, 0]}>
+#warp1 = #triton_intel_gpu.warp<{sizePerThread = [16, 32], threadsPerWarp = [1, 1], order = [1, 0]}>
+module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 8 : i32, "triton_gpu.threads-per-warp" = 1 : i32, triton_intel_gpu.min_sg_size = 16 : i32, triton_intel_gpu.support_bf16_conversion, triton_intel_gpu.support_dpas, triton_intel_gpu.support_sg_2d_block} {
+  tt.func public @_attn_fwd(%arg0: !tt.ptr<f16>, %arg1: !tt.ptr<f16>, %arg2: !tt.ptr<f16>, %arg3: f32, %arg4: !tt.ptr<f32>, %arg5: !tt.ptr<f32>) attributes {noinline = false} {
+    %c16_i32 = arith.constant 16 : i32
+    %c131072_i64 = arith.constant 131072 : i64
+    %c65536_i64 = arith.constant 65536 : i64
+    %c128_i32 = arith.constant 128 : i32
+    %c1024_i64 = arith.constant 1024 : i64
+    %c64_i64 = arith.constant 64 : i64
+    %c1_i64 = arith.constant 1 : i64
+    %c0_i32 = arith.constant 0 : i32
+    %cst = arith.constant 1.44269502 : f32
+    %cst_0 = arith.constant dense<0.000000e+00> : tensor<16x64xf32, #warp>
+    %c64_i32 = arith.constant 64 : i32
+    %c1_i32 = arith.constant 1 : i32
+    %0 = gpu.subgroup_id : index
+    %1 = arith.index_cast %0 : index to i32
+    %2 = tt.get_program_id z : i32
+    %3 = tt.get_program_id x : i32
+    %4 = tt.get_program_id y : i32
+    %5 = arith.extsi %3 : i32 to i64
+    %6 = arith.muli %5, %c131072_i64 : i64
+    %7 = arith.extsi %4 : i32 to i64
+    %8 = arith.muli %7, %c65536_i64 : i64
+    %9 = arith.addi %6, %8 : i64
+    %10 = tt.addptr %arg0, %9 : !tt.ptr<f16>, i64
+    %11 = arith.muli %2, %c128_i32 : i32
+    %12 = arith.muli %1, %c16_i32 : i32
+    %13 = arith.addi %12, %11 : i32
+    %14 = tt.make_tensor_ptr %10, [%c1024_i64, %c64_i64], [%c64_i64, %c1_i64], [%13, %c0_i32] {order = array<i32: 1, 0>} : <tensor<16x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #warp}>>>
+    %28 = tt.addptr %arg1, %9 : !tt.ptr<f16>, i64
+    %34 = tt.make_tensor_ptr %28, [%c64_i64, %c1024_i64], [%c1_i64, %c64_i64], [%c0_i32, %c0_i32] {order = array<i32: 0, 1>} : <tensor<64x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #warp}>>>
+    %35 = tt.addptr %arg5, %9 : !tt.ptr<f32>, i64
+    %36 = tt.make_tensor_ptr %35, [%c1024_i64, %c64_i64], [%c64_i64, %c1_i64], [%13, %c0_i32] {order = array<i32: 1, 0>} : <tensor<16x64xf32, #warp>>
+    %44 = tt.load %14 : !tt.ptr<tensor<16x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #warp}>>>
+    %47:2 = scf.for %arg6 = %c0_i32 to %11 step %c64_i32 iter_args(%arg7 = %cst_0, %arg11 = %34) -> (tensor<16x64xf32, #warp>, !tt.ptr<tensor<64x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #warp}>>>) : i32 {
+      // CHECK-COUNT-16: tt.load {{%.*}} {DotIdx = 1 : i32} : !tt.ptr<tensor<16x16xf16>>
+      %60 = tt.load %arg11 : !tt.ptr<tensor<64x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #warp}>>>
+      %61 = tt.dot %44, %60, %cst_0, inputPrecision = tf32 : tensor<16x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #warp}>> * tensor<64x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #warp}>> -> tensor<16x64xf32, #warp>
+      // CHECK-COUNT-16: tt.advance {{%.*}}, [%c0_i32, %c64_i32] {DotIdx = 1 : i32} : <tensor<16x16xf16>>
+      %85 = tt.advance %arg11, [%c0_i32, %c64_i32] : <tensor<64x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #warp}>>>
+      scf.yield %61, %85 : tensor<16x64xf32, #warp>, !tt.ptr<tensor<64x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #warp}>>>
+    } {triton_gpu.workload = 4 : i32, tt.divisibility_arg1 = dense<64> : tensor<1xi32>}
+    // CHECK: gpu.barrier
+    gpu.barrier
+    %48 = arith.muli %2, %c128_i32 {tt.divisibility = dense<128> : tensor<1xi32>} : i32
+    %49 = arith.addi %2, %c1_i32 : i32
+    %50 = arith.muli %49, %c128_i32 : i32
+    %51 = tt.advance %34, [%c0_i32, %48] : <tensor<64x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #warp}>>>
+    %56:2 = scf.for %arg6 = %48 to %50 step %c64_i32 iter_args(%arg7 = %47#0, %arg11 = %51) -> (tensor<16x64xf32, #warp>, !tt.ptr<tensor<64x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #warp}>>>) : i32 {
+      // CHECK-COUNT-16: tt.load {{%.*}} {DotIdx = 1 : i32} : !tt.ptr<tensor<16x16xf16>>
+      %60 = tt.load %arg11 : !tt.ptr<tensor<64x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #warp}>>>
+      %61 = tt.dot %44, %60, %cst_0, inputPrecision = tf32 : tensor<16x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #warp}>> * tensor<64x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #warp}>> -> tensor<16x64xf32, #warp>
+      // CHECK-COUNT-16: tt.advance {{%.*}}, [%c0_i32, %c64_i32] {DotIdx = 1 : i32} : <tensor<16x16xf16>>
+      %88 = tt.advance %arg11, [%c0_i32, %c64_i32] : <tensor<64x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #warp}>>>
+      scf.yield %61, %88 : tensor<16x64xf32, #warp>, !tt.ptr<tensor<64x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #warp}>>>
+    } {triton_gpu.workload = 4 : i32}
+    tt.store %36, %56#0 : !tt.ptr<tensor<16x64xf32, #warp>>
+    tt.return
+  }
+}

third_party/intel/lib/TritonIntelGPUTransforms/MatchTargetSize.cpp

Lines changed: 48 additions & 32 deletions
@@ -149,7 +149,6 @@ static tt::LoadOp findUsedLoad(Value val) {
 }
 
 static bool getTransposeFlagFromValue(Value val) {
-  bool isTransposed = false;
   Value loadPtr = val;
   // backward: from dot operands to tt.load
   if (llvm::any_of(val.getUsers(),
@@ -167,23 +166,26 @@ static bool getTransposeFlagFromValue(Value val) {
   if (auto blockArg = dyn_cast<BlockArgument>(loadPtr)) {
     unsigned argIdx = blockArg.getArgNumber();
     if (auto loopLikeOp = dyn_cast<LoopLikeOpInterface>(
-            blockArg.getParentBlock()->getParentOp())) {
-      auto inits = llvm::to_vector(loopLikeOp.getInits());
-      if (auto glueOp = inits[argIdx - 1].getDefiningOp<ttgi::GlueOp>()) {
-        if (auto tempPtr =
-                glueOp.getOperands()[0].getDefiningOp<tt::MakeTensorPtrOp>()) {
-          loadPtr = tempPtr.getResult();
-        }
-      }
-    }
+            blockArg.getParentBlock()->getParentOp()))
+      loadPtr = loopLikeOp.getInits()[argIdx - 1];
+  }
+
+  if (auto glueOp = loadPtr.getDefiningOp<ttgi::GlueOp>()) {
+    if (isa_and_present<tt::MakeTensorPtrOp, tt::AdvanceOp>(
+            glueOp.getOperands()[0].getDefiningOp()))
+      loadPtr = glueOp.getOperands()[0];
   }
 
   if (auto tensorPtr = loadPtr.getDefiningOp<tt::MakeTensorPtrOp>()) {
     ArrayRef<int32_t> order = tensorPtr.getOrder();
     auto rank = order.size();
-    isTransposed = (order[rank - 2] != 1);
+    return (order[rank - 2] != 1);
   }
-  return isTransposed;
+
+  if (auto advOp = loadPtr.getDefiningOp<tt::AdvanceOp>())
+    return getTransposeFlagFromValue(advOp.getPtr());
+
+  return false;
 }
 
 static void rewriteLoadWithSLM(ModuleOp &m, DenseSet<Value> &dotWithSLMOperands,
@@ -275,6 +277,14 @@ class MatchTargetSizePass
     MLIRContext *ctx = &getContext();
     ModuleOp m = getOperation();
 
+    // By default, tritongpu are lowered to simt mode (threads-per-warp=16)
+    // instead of simd mode (threads-per-warp=1).
+    // FIXME: force threads-per-warp=16 in simt(this should be done via an
+    // analysis designed to determine whether the kernel contains tt.dot
+    // operations that use block pointers).
+    m->setAttr("triton_gpu.threads-per-warp",
+               IntegerAttr::get(IntegerType::get(ctx, 32), 16));
+
     Workload workload = Workload::None;
     m.walk([&](scf::ForOp forOp) {
       if (Attribute attr = forOp->getAttr(AttrWorkloadName))
@@ -352,14 +362,6 @@ class MatchTargetSizePass
     canonicalize();
     LLVM_DEBUG(llvm::dbgs() << "Module after canonicalization:\n"
                             << m << "\n\n");
-
-    // By default, tritongpu are lowered to simt mode (threads-per-warp=16)
-    // instead of simd mode (threads-per-warp=1).
-    // FIXME: force threads-per-warp=16 in simt(this should be done via an
-    // analysis designed to determine whether the kernel contains tt.dot
-    // operations that use block pointers).
-    m->setAttr("triton_gpu.threads-per-warp",
-               IntegerAttr::get(IntegerType::get(ctx, 32), 16));
   }
 
 private:
@@ -379,8 +381,8 @@ class MatchTargetSizePass
                                     bool isTransposed) const;
 
   std::tuple<SmallVector<int64_t>, Type, SmallVector<int64_t>>
-  getSubTypeAndShape(Type type, bool isTransposed = false,
-                     bool useSLM = false) const;
+  getSubTypeAndShape(Type type, bool isTransposed = false, bool useSLM = false,
+                     bool keepEncoding = false) const;
 
   Value getSubVal(Operation *op, Value val, ArrayRef<int64_t> srcOffset,
                   ArrayRef<int64_t> dstSize);
@@ -753,7 +755,7 @@ MatchTargetSizePass::getSubOpSize(RankedTensorType type,
 /// return [shape, subType, subSize] for a tensor (or pointer to tensor)
 std::tuple<SmallVector<int64_t>, Type, SmallVector<int64_t>>
 MatchTargetSizePass::getSubTypeAndShape(Type type, bool isTransposed,
-                                        bool useSLM) const {
+                                        bool useSLM, bool keepEncoding) const {
   if (auto tensorType = dyn_cast<RankedTensorType>(type)) {
     Attribute layout = tensorType.getEncoding();
     assert(layout && "Expecting a valid layout");
@@ -771,15 +773,16 @@ MatchTargetSizePass::getSubTypeAndShape(Type type, bool isTransposed,
       subSize[1] = std::min(subSize[1], shape[1]);
     }
 
-    auto subType = RankedTensorType::get(
-        subSize, tensorType.getElementType() /*no encoding*/);
+    auto subType = RankedTensorType::get(subSize, tensorType.getElementType(),
+                                         keepEncoding ? tensorType.getEncoding()
+                                                      : Attribute{});
     return {shape, subType, subSize};
   }
 
   if (auto ptrType = dyn_cast<tt::PointerType>(type)) {
     Type pointeeType = ptrType.getPointeeType();
     auto [shape, subType, subSize] =
-        getSubTypeAndShape(pointeeType, isTransposed, useSLM);
+        getSubTypeAndShape(pointeeType, isTransposed, useSLM, keepEncoding);
     auto newType = tt::PointerType::get(subType, ptrType.getAddressSpace());
     return {shape, newType, subSize};
   }
@@ -1186,8 +1189,11 @@ void MatchTargetSizePass::transformBroadcastOp(ttgi::BroadcastOp op) {
     glue = b.create<ttgi::GlueOp>(loc, resType, ops);
   } else if (srcDim0 == 1 && srcDim1 == resDim1) {
     // Handle row-vector broadcasts, e.g. 1x64 --> 16x64.
+    // This kind of broadcast requires that the tensor type is kept intact by
+    // SIMT lowering, hence propagate the encoding here.
     auto subRowVecTy =
-        RankedTensorType::get({1, tType.getShape()[1]}, tType.getElementType());
+        RankedTensorType::get({1, tType.getShape()[1]}, tType.getElementType(),
+                              srcType.getEncoding());
 
     // How many extracts do we need to cover the width of the input tensor?
     unsigned nExtracts = srcDim1 / dstDim1;
@@ -1222,9 +1228,10 @@ void MatchTargetSizePass::transformMakeRangeOp(tt::MakeRangeOp op) {
 
   unsigned start = op.getStart();
   unsigned end = op.getEnd();
-  assert(start == 0 && end % subgroupSize == 0 && "Unsupported range");
+  assert(start == 0 && (end <= subgroupSize || end % subgroupSize == 0) &&
+         "Unsupported range");
 
-  if (end == subgroupSize)
+  if (end <= subgroupSize)
     // nothing to do
     return;
 
@@ -1240,6 +1247,7 @@ void MatchTargetSizePass::transformMakeRangeOp(tt::MakeRangeOp op) {
   Location loc = op.getLoc();
   RankedTensorType origTy = op.getType();
   Type elemTy = origTy.getElementType();
+  // Propagate encoding to keep tensor during SIMT lowering.
   auto subRangeTy =
       RankedTensorType::get({subgroupSize}, elemTy, origTy.getEncoding());
   auto subRange = b.create<tt::MakeRangeOp>(loc, subRangeTy, 0, subgroupSize);
@@ -1310,8 +1318,16 @@ void MatchTargetSizePass::transformGenericOp(Operation *op) {
         cast<tt::PointerType>(load.getPtr().getType()).getAddressSpace();
     useSLM = (ptrAS == TritonGEN::TritonGENMemorySpace::kWorkgroup);
   }
+
+  // Keep encoding on certain tensors to leave them untouched during SIMT
+  // lowering. Currently, this is required for "row vectors" (= `tensor<1xN>`).
+  bool keepEncoding = false;
+  if (auto tensorType = dyn_cast<RankedTensorType>(type)) {
+    ArrayRef<int64_t> shape = tensorType.getShape();
+    keepEncoding = shape.size() == 2 && shape[0] == 1 && shape[1] > 1;
+  }
   auto [shape, subType, subSize] =
-      getSubTypeAndShape(type, isTransposed, useSLM);
+      getSubTypeAndShape(type, isTransposed, useSLM, keepEncoding);
 
   unsigned dim = shape.size();
   OpBuilder b(op);
@@ -1328,8 +1344,8 @@ void MatchTargetSizePass::transformGenericOp(Operation *op) {
             [&](Value operand) {
               Type type = operand.getType();
               if (isa<tt::PointerType, RankedTensorType>(type)) {
-                Type subOpndType = std::get<1>(
-                    getSubTypeAndShape(type, isTransposed, useSLM));
+                Type subOpndType = std::get<1>(getSubTypeAndShape(
+                    type, isTransposed, useSLM, keepEncoding));
                 Value newOp = b.create<ttgi::ExtractOp>(
                     loc, subOpndType, operand, idx);
                 return newOp;
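Taken together, the MatchTargetSize.cpp hunks above implement the first bullet of the commit message: a "row vector" (`tensor<1xN>`) keeps its encoding so the SIMT lowering leaves it untouched. Below is a minimal sketch of that decision, assuming the usual MLIR includes and using hypothetical helper names; the real logic lives inline in `transformGenericOp` and `getSubTypeAndShape`.

```cpp
// Sketch only; helper names are illustrative, not the actual pass API.
#include "mlir/IR/BuiltinTypes.h"

using namespace mlir;

// A "row vector" is a rank-2 tensor with a unit leading dimension,
// e.g. tensor<1x64xi32, #warp>. Such tensors must keep their encoding.
static bool shouldKeepEncoding(Type type) {
  if (auto tensorType = dyn_cast<RankedTensorType>(type)) {
    ArrayRef<int64_t> shape = tensorType.getShape();
    return shape.size() == 2 && shape[0] == 1 && shape[1] > 1;
  }
  return false;
}

// When the encoding must survive, propagate it onto the sub-tile type;
// otherwise it is dropped, as before this change.
static RankedTensorType makeSubTileType(RankedTensorType tensorType,
                                        ArrayRef<int64_t> subSize,
                                        bool keepEncoding) {
  return RankedTensorType::get(subSize, tensorType.getElementType(),
                               keepEncoding ? tensorType.getEncoding()
                                            : Attribute{});
}
```

This is also why the CHECK lines in the first match-target-size.mlir hunk now expect the `#warp` encoding to survive on the extracted `tensor<1x16xi32>` values.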
