Skip to content

Commit d4f100c

Browse files
authored
[Codegen] Materialize 0D set_encoding into no-op (#21418)
We can have set_encoding ops on 0D tensors if encodings are propagated through generic ops with 0D tensor operands. If these operands are materialized normally, then 0D transpose ops get generated, which breaks the assumption of a non-0D permutation and causes an assertion error. This PR materializes set_encoding ops on 0D tensors into a no-op, since 0D tensors are just scalars and are not affected by data-tiling layouts. Signed-off-by: Max Dawkins <[email protected]>
1 parent 936f5da commit d4f100c

File tree

2 files changed

+25
-0
lines changed

2 files changed

+25
-0
lines changed

compiler/src/iree/compiler/Codegen/Common/MaterializeEncodingPatterns.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -680,6 +680,10 @@ struct SetEncodingOpLoweringConversion
680680
LogicalResult
681681
matchAndRewrite(IREE::Encoding::SetEncodingOp encodingOp, OpAdaptor adaptor,
682682
ConversionPatternRewriter &rewriter) const override {
683+
if (encodingOp.getSource().getType().getRank() == 0) {
684+
rewriter.replaceOp(encodingOp, adaptor.getSource());
685+
return success();
686+
}
683687
auto converter = static_cast<const MaterializeEncodingTypeConverter *>(
684688
getTypeConverter());
685689
auto packedValue = lowerSetEncodingOpToPackOp(

compiler/src/iree/compiler/Codegen/Common/test/materialize_encoding_gfx942.mlir

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1314,3 +1314,24 @@ func.func @dequantization() {
13141314
// CHECK: arith.subf
13151315
// CHECK: arith.mulf
13161316
// CHECK: iree_tensor_ext.dispatch.tensor.store %[[LHS_DEQUANT]], %[[RESULT_BINDING]], offsets = [0, 0, 0, 0, 0, 0, 0, 0], sizes = [2, 1, 4, 8, 4, 4, 4, 4], strides = [1, 1, 1, 1, 1, 1, 1, 1] : tensor<2x1x4x8x4x4x4x4xf32> -> !iree_tensor_ext.dispatch.tensor<writeonly:tensor<2x1x4x8x4x4x4x4xf32>>
1317+
1318+
// -----
1319+
1320+
#encoding = #iree_encoding.encoding<operand_index = 2 : index, op_type = matmul, element_types = [f16, f16, f32],
1321+
user_indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d1, d2)>, [affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1) -> ()>]]>
1322+
#pipeline_layout = #hal.pipeline.layout<bindings = [
1323+
#hal.pipeline.binding<storage_buffer>,
1324+
#hal.pipeline.binding<storage_buffer>
1325+
]>
1326+
func.func @set_encoding_0D_tensor() {
1327+
%c0 = arith.constant 0 : index
1328+
%0 = hal.interface.binding.subspan layout(#pipeline_layout) binding(0) alignment(64) offset(%c0) flags(ReadOnly) : !iree_tensor_ext.dispatch.tensor<readonly:tensor<f32>>
1329+
%1 = hal.interface.binding.subspan layout(#pipeline_layout) binding(1) alignment(64) offset(%c0) : !iree_tensor_ext.dispatch.tensor<writeonly:tensor<f32, #encoding>>
1330+
%2 = iree_tensor_ext.dispatch.tensor.load %0, offsets = [], sizes = [], strides = [] : !iree_tensor_ext.dispatch.tensor<readonly:tensor<f32>> -> tensor<f32>
1331+
%3 = iree_encoding.set_encoding %2 : tensor<f32> -> tensor<f32, #encoding>
1332+
iree_tensor_ext.dispatch.tensor.store %3, %1, offsets = [], sizes = [], strides = [] : tensor<f32, #encoding> -> !iree_tensor_ext.dispatch.tensor<writeonly:tensor<f32, #encoding>>
1333+
return
1334+
}
1335+
// CHECK-LABEL: func.func @set_encoding_0D_tensor()
1336+
// CHECK: %[[INPUT:.+]] = iree_tensor_ext.dispatch.tensor.load
1337+
// CHECK: iree_tensor_ext.dispatch.tensor.store %[[INPUT]]

0 commit comments

Comments
 (0)