Skip to content

Commit 2964b06

Browse files
committed
Addressing code review comments
Signed-off-by: Mohammadreza Ameri Mahabadian <[email protected]>
1 parent b0e0f23 commit 2964b06

File tree

2 files changed

+10
-9
lines changed

2 files changed

+10
-9
lines changed

mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1792,11 +1792,9 @@ spirv::Deserializer::processConstantNull(ArrayRef<uint32_t> operands) {
17921792
Attribute attr;
17931793
if (resultType.isIntOrFloat() || isa<VectorType>(resultType)) {
17941794
attr = opBuilder.getZeroAttr(resultType);
1795-
} else if (isa<TensorArmType>(resultType)) {
1796-
auto shapedType = cast<ShapedType>(resultType);
1797-
auto element = opBuilder.getZeroAttr(shapedType.getElementType());
1798-
if (element)
1799-
attr = DenseElementsAttr::get(shapedType, element);
1795+
} else if (auto tensorType = dyn_cast<TensorArmType>(resultType)) {
1796+
if (auto element = opBuilder.getZeroAttr(tensorType.getElementType()))
1797+
attr = DenseElementsAttr::get(tensorType, element);
18001798
}
18011799

18021800
if (attr) {

mlir/lib/Target/SPIRV/Serialization/Serializer.cpp

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,7 @@ static Block *getPhiIncomingBlock(Block *block) {
6969
return block;
7070
}
7171

72-
static bool isNull(Attribute attr) {
72+
static bool isZeroValue(Attribute attr) {
7373
if (auto floatAttr = dyn_cast<FloatAttr>(attr)) {
7474
return floatAttr.getValue().isZero();
7575
}
@@ -79,8 +79,11 @@ static bool isNull(Attribute attr) {
7979
if (auto intAttr = dyn_cast<IntegerAttr>(attr)) {
8080
return intAttr.getValue().isZero();
8181
}
82+
if (auto splatElemAttr = dyn_cast<SplatElementsAttr>(attr)) {
83+
return isZeroValue(splatElemAttr.getSplatValue<Attribute>());
84+
}
8285
if (auto denseElemAttr = dyn_cast<DenseElementsAttr>(attr)) {
83-
return all_of(denseElemAttr.getValues<Attribute>(), isNull);
86+
return all_of(denseElemAttr.getValues<Attribute>(), isZeroValue);
8487
}
8588
return false;
8689
}
@@ -975,7 +978,7 @@ Serializer::prepareDenseElementsConstant(Location loc, Type constType,
975978
return 0;
976979
}
977980
} else if (isa<spirv::TensorArmType>(constType)) {
978-
if (isNull(valueAttr)) {
981+
if (isZeroValue(valueAttr)) {
979982
encodeInstructionInto(typesGlobalValues, spirv::Opcode::OpConstantNull,
980983
{typeID, resultID});
981984
return resultID;
@@ -1223,7 +1226,7 @@ uint32_t Serializer::prepareConstantCompositeReplicate(Location loc,
12231226
}
12241227

12251228
uint32_t resultID = getNextID();
1226-
if (dyn_cast<spirv::TensorArmType>(resultType) && isNull(valueAttr)) {
1229+
if (dyn_cast<spirv::TensorArmType>(resultType) && isZeroValue(valueAttr)) {
12271230
encodeInstructionInto(typesGlobalValues, spirv::Opcode::OpConstantNull,
12281231
{typeID, resultID});
12291232
} else {

0 commit comments

Comments
 (0)