Skip to content

Commit 1c7b3a1

Browse files
jtuylskeshavvinayak01
authored andcommitted
[Encoding] Add new identity encoding attribute (iree-org#21258)
Currently, the padding encoding attribute with sizes all zero is used to represent the identity encoding. This PR replaces that with an explicit identity encoding attribute which avoids the confusion of padding attributes showing up in non-padding related data-tiling pipelines. Signed-off-by: Jorn Tuyls <[email protected]> Signed-off-by: keshavvinayak01 <[email protected]>
1 parent ebbbe61 commit 1c7b3a1

File tree

4 files changed

+66
-5
lines changed

4 files changed

+66
-5
lines changed

compiler/src/iree/compiler/Dialect/Encoding/IR/EncodingAttrs.cpp

Lines changed: 34 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -496,6 +496,39 @@ LogicalResult PaddingAttr::verify(function_ref<InFlightDiagnostic()> emitError,
496496
return success();
497497
}
498498

499+
//===---------------------------------------------------------------------===//
500+
// iree_encoding.identity
501+
//===---------------------------------------------------------------------===//
502+
503+
Value IdentityAttr::calculateStorageSizeInBytes(Location loc,
504+
OpBuilder &builder,
505+
RankedTensorType type,
506+
ValueRange dynamicDims) const {
507+
const int64_t elementSize = getRoundedElementByteWidth(type.getElementType());
508+
int64_t staticProduct = elementSize;
509+
Value dynamicProduct = builder.create<arith::ConstantIndexOp>(loc, 1);
510+
511+
size_t dynamicDimIdx = 0;
512+
for (int64_t dimSize : type.getShape()) {
513+
if (!ShapedType::isDynamic(dimSize)) {
514+
staticProduct *= dimSize;
515+
continue;
516+
}
517+
518+
Value dynamicDimSize = dynamicDims[dynamicDimIdx];
519+
++dynamicDimIdx;
520+
dynamicProduct = builder.createOrFold<arith::MulIOp>(
521+
loc, dynamicProduct, dynamicDimSize, arith::IntegerOverflowFlags::nsw);
522+
}
523+
return builder.createOrFold<arith::MulIOp>(
524+
loc, builder.create<arith::ConstantIndexOp>(loc, staticProduct),
525+
dynamicProduct, arith::IntegerOverflowFlags::nsw);
526+
}
527+
528+
bool IdentityAttr::isIdentityLayout() const { return true; }
529+
530+
bool IdentityAttr::isSerialized() const { return true; }
531+
499532
//===---------------------------------------------------------------------===//
500533
// iree_encoding.identity_resolver
501534
//===---------------------------------------------------------------------===//
@@ -506,9 +539,7 @@ IdentityResolverAttr::cloneWithSimplifiedConfig(DictionaryAttr) const {
506539
}
507540

508541
Attribute IdentityResolverAttr::getLayout(RankedTensorType type) const {
509-
MLIRContext *ctx = getContext();
510-
SmallVector<int64_t> zeros(type.getRank(), 0);
511-
return Encoding::PaddingAttr::get(ctx, DenseI64ArrayAttr::get(ctx, zeros));
542+
return Encoding::IdentityAttr::get(getContext());
512543
}
513544

514545
//===---------------------------------------------------------------------===//

compiler/src/iree/compiler/Dialect/Encoding/IR/EncodingAttrs.td

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -297,6 +297,28 @@ def PaddingAttr : IREEEncoding_Attr<"Padding", [
297297
let genVerifyDecl = 1;
298298
}
299299

300+
//===---------------------------------------------------------------------===//
301+
// iree_encoding.identity
302+
//===---------------------------------------------------------------------===//
303+
304+
def IdentityAttr : IREEEncoding_Attr<"Identity", [
305+
DeclareAttrInterfaceMethods<IREEEncoding_SerializableAttr, [
306+
"calculateStorageSizeInBytes",
307+
"isIdentityLayout",
308+
"isSerialized"
309+
]>
310+
]> {
311+
let mnemonic = "identity";
312+
313+
let summary = [{The identity encoding.}];
314+
let description = [{
315+
An encoding attribute that represents the identity function on a type, i.e.
316+
it represents the same type as if there was no encoding.
317+
}];
318+
319+
let genVerifyDecl = 0;
320+
}
321+
300322
//===---------------------------------------------------------------------===//
301323
// iree_encoding.identity_resolver
302324
//===---------------------------------------------------------------------===//

compiler/src/iree/compiler/Dialect/Encoding/IR/test/roundtrip.mlir

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -262,3 +262,11 @@ func.func @dynamic_layout_encoding(%arg0: tensor<?x?x?xf32, #encoding>) -> tenso
262262
}
263263
// CHECK: #[[ENCODING:.+]] = #iree_encoding.layout<[#iree_encoding.padding<[0, ?, 64]>]>
264264
// CHECK: func.func @dynamic_layout_encoding(%[[ARG0:.+]]: tensor<?x?x?xf32, #[[ENCODING]]>)
265+
266+
// -----
267+
268+
#encoding = #iree_encoding.identity
269+
func.func @identity_encoding(%arg0: tensor<?x?xf32, #encoding>) -> tensor<?x?xf32, #encoding> {
270+
return %arg0 : tensor<?x?xf32, #encoding>
271+
}
272+
// CHECK: func.func @identity_encoding(%[[ARG0:.+]]: tensor<?x?xf32, #iree_encoding.identity>

compiler/src/iree/compiler/Dialect/Stream/Transforms/test/specialize_encodings.mlir

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -459,7 +459,7 @@ util.func public @drop_encoding(%arg0: index, %arg1: index, %scalar_f32 : f32) {
459459
%0 = stream.tensor.empty on(#hal.device.affinity<@device_a>) : tensor<?x0xf32, #encoding>{%arg0} in !stream.resource<*>{%arg1}
460460
util.return
461461
}
462-
// CHECK-DAG: #[[$IDENTITY_ENCODING:.+]] = #iree_encoding.testing<[#iree_encoding.padding<[0, 0]>]>
462+
// CHECK-DAG: #[[$IDENTITY_ENCODING:.+]] = #iree_encoding.testing<[#iree_encoding.identity]>
463463
// CHECK-LABEL: util.func public @drop_encoding
464464
// CHECK: stream.tensor.empty {{.+}} : tensor<?x0xf32, #[[$IDENTITY_ENCODING]]>
465465

@@ -476,7 +476,7 @@ util.func public @ignore_encoding_by_identity_resolver(%arg0: index, %arg1: inde
476476
%0 = stream.tensor.empty on(#hal.device.affinity<@device_a>) : tensor<?x0xf32, #encoding>{%arg0} in !stream.resource<*>{%arg1}
477477
util.return
478478
}
479-
// CHECK-DAG: #[[$IDENTITY_ENCODING:.+]] = #iree_encoding.testing<[#iree_encoding.padding<[0, 0]>]>
479+
// CHECK-DAG: #[[$IDENTITY_ENCODING:.+]] = #iree_encoding.testing<[#iree_encoding.identity]>
480480
// CHECK-LABEL: util.func public @ignore_encoding_by_identity_resolver
481481
// CHECK: stream.tensor.empty {{.+}} : tensor<?x0xf32, #[[$IDENTITY_ENCODING]]>
482482

0 commit comments

Comments
 (0)