@@ -69,7 +69,7 @@ static Block *getPhiIncomingBlock(Block *block) {
6969 return block;
7070}
7171
72- static bool isNull (Attribute attr) {
72+ static bool isZeroValue (Attribute attr) {
7373 if (auto floatAttr = dyn_cast<FloatAttr>(attr)) {
7474 return floatAttr.getValue ().isZero ();
7575 }
@@ -79,8 +79,11 @@ static bool isNull(Attribute attr) {
7979 if (auto intAttr = dyn_cast<IntegerAttr>(attr)) {
8080 return intAttr.getValue ().isZero ();
8181 }
82+ if (auto splatElemAttr = dyn_cast<SplatElementsAttr>(attr)) {
83+ return isZeroValue (splatElemAttr.getSplatValue <Attribute>());
84+ }
8285 if (auto denseElemAttr = dyn_cast<DenseElementsAttr>(attr)) {
83- return all_of (denseElemAttr.getValues <Attribute>(), isNull );
86+ return all_of (denseElemAttr.getValues <Attribute>(), isZeroValue );
8487 }
8588 return false ;
8689}
@@ -975,7 +978,7 @@ Serializer::prepareDenseElementsConstant(Location loc, Type constType,
975978 return 0 ;
976979 }
977980 } else if (isa<spirv::TensorArmType>(constType)) {
978- if (isNull (valueAttr)) {
981+ if (isZeroValue (valueAttr)) {
979982 encodeInstructionInto (typesGlobalValues, spirv::Opcode::OpConstantNull,
980983 {typeID, resultID});
981984 return resultID;
@@ -1223,7 +1226,7 @@ uint32_t Serializer::prepareConstantCompositeReplicate(Location loc,
12231226 }
12241227
12251228 uint32_t resultID = getNextID ();
1226- if (dyn_cast<spirv::TensorArmType>(resultType) && isNull (valueAttr)) {
1229+ if (dyn_cast<spirv::TensorArmType>(resultType) && isZeroValue (valueAttr)) {
12271230 encodeInstructionInto (typesGlobalValues, spirv::Opcode::OpConstantNull,
12281231 {typeID, resultID});
12291232 } else {
0 commit comments