Skip to content

Commit b0e0f23

Browse files
committed
[mlir][spirv] Enable (de)serialization of TensorARM to/from OpConstantNull
This patch enables (de)serialization to/from OpConstantNull for null TensorARM Signed-off-by: Mohammadreza Ameri Mahabadian <[email protected]>
1 parent 6752415 commit b0e0f23

File tree

3 files changed

+68
-7
lines changed

3 files changed

+68
-7
lines changed

mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1779,7 +1779,7 @@ LogicalResult
17791779
spirv::Deserializer::processConstantNull(ArrayRef<uint32_t> operands) {
17801780
if (operands.size() != 2) {
17811781
return emitError(unknownLoc,
1782-
"OpConstantNull must have type <id> and result <id>");
1782+
"OpConstantNull must only have type <id> and result <id>");
17831783
}
17841784

17851785
Type resultType = getType(operands[0]);
@@ -1789,8 +1789,17 @@ spirv::Deserializer::processConstantNull(ArrayRef<uint32_t> operands) {
17891789
}
17901790

17911791
auto resultID = operands[1];
1792+
Attribute attr;
17921793
if (resultType.isIntOrFloat() || isa<VectorType>(resultType)) {
1793-
auto attr = opBuilder.getZeroAttr(resultType);
1794+
attr = opBuilder.getZeroAttr(resultType);
1795+
} else if (isa<TensorArmType>(resultType)) {
1796+
auto shapedType = cast<ShapedType>(resultType);
1797+
auto element = opBuilder.getZeroAttr(shapedType.getElementType());
1798+
if (element)
1799+
attr = DenseElementsAttr::get(shapedType, element);
1800+
}
1801+
1802+
if (attr) {
17941803
// For normal constants, we just record the attribute (and its type) for
17951804
// later materialization at use sites.
17961805
constantMap.try_emplace(resultID, attr, resultType);

mlir/lib/Target/SPIRV/Serialization/Serializer.cpp

Lines changed: 29 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,22 @@ static Block *getPhiIncomingBlock(Block *block) {
6969
return block;
7070
}
7171

72+
static bool isNull(Attribute attr) {
73+
if (auto floatAttr = dyn_cast<FloatAttr>(attr)) {
74+
return floatAttr.getValue().isZero();
75+
}
76+
if (auto boolAttr = dyn_cast<BoolAttr>(attr)) {
77+
return !boolAttr.getValue();
78+
}
79+
if (auto intAttr = dyn_cast<IntegerAttr>(attr)) {
80+
return intAttr.getValue().isZero();
81+
}
82+
if (auto denseElemAttr = dyn_cast<DenseElementsAttr>(attr)) {
83+
return all_of(denseElemAttr.getValues<Attribute>(), isNull);
84+
}
85+
return false;
86+
}
87+
7288
namespace mlir {
7389
namespace spirv {
7490

@@ -959,6 +975,11 @@ Serializer::prepareDenseElementsConstant(Location loc, Type constType,
959975
return 0;
960976
}
961977
} else if (isa<spirv::TensorArmType>(constType)) {
978+
if (isNull(valueAttr)) {
979+
encodeInstructionInto(typesGlobalValues, spirv::Opcode::OpConstantNull,
980+
{typeID, resultID});
981+
return resultID;
982+
}
962983
numberOfConstituents = shapedType.getNumElements();
963984
operands.reserve(numberOfConstituents + 2);
964985
for (int i = 0; i < numberOfConstituents; ++i) {
@@ -1202,11 +1223,14 @@ uint32_t Serializer::prepareConstantCompositeReplicate(Location loc,
12021223
}
12031224

12041225
uint32_t resultID = getNextID();
1205-
uint32_t operands[] = {typeID, resultID, constandID};
1206-
1207-
encodeInstructionInto(typesGlobalValues,
1208-
spirv::Opcode::OpConstantCompositeReplicateEXT,
1209-
operands);
1226+
if (dyn_cast<spirv::TensorArmType>(resultType) && isNull(valueAttr)) {
1227+
encodeInstructionInto(typesGlobalValues, spirv::Opcode::OpConstantNull,
1228+
{typeID, resultID});
1229+
} else {
1230+
encodeInstructionInto(typesGlobalValues,
1231+
spirv::Opcode::OpConstantCompositeReplicateEXT,
1232+
{typeID, resultID, constandID});
1233+
}
12101234

12111235
constCompositeReplicateIDMap[valueTypePair] = resultID;
12121236
return resultID;

mlir/test/Target/SPIRV/constant.mlir

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -335,6 +335,20 @@ spirv.module Logical Vulkan requires #spirv.vce<v1.3, [VulkanMemoryModel, Shader
335335
spirv.ReturnValue %0 : !spirv.arm.tensor<2x3xf32>
336336
}
337337

338+
// CHECK-LABEL: @null_arm_tensor_of_i32
339+
spirv.func @null_arm_tensor_of_i32() -> (!spirv.arm.tensor<2x3xi32>) "None" {
340+
// CHECK: spirv.Constant dense<0> : !spirv.arm.tensor<2x3xi32>
341+
%0 = spirv.Constant dense<0> : !spirv.arm.tensor<2x3xi32>
342+
spirv.ReturnValue %0 : !spirv.arm.tensor<2x3xi32>
343+
}
344+
345+
// CHECK-LABEL: @null_arm_tensor_of_f32
346+
spirv.func @null_arm_tensor_of_f32() -> (!spirv.arm.tensor<2x3xf32>) "None" {
347+
// CHECK: spirv.Constant dense<0.000000e+00> : !spirv.arm.tensor<2x3xf32>
348+
%0 = spirv.Constant dense<0.0> : !spirv.arm.tensor<2x3xf32>
349+
spirv.ReturnValue %0 : !spirv.arm.tensor<2x3xf32>
350+
}
351+
338352
spirv.EntryPoint "GLCompute" @bool_const
339353
}
340354

@@ -391,6 +405,13 @@ spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader, ReplicatedCompos
391405
spirv.ReturnValue %0 : !spirv.arm.tensor<2x3xi32>
392406
}
393407

408+
// CHECK-LABEL: @null_cc_arm_tensor_of_i32
409+
spirv.func @null_cc_arm_tensor_of_i32() -> (!spirv.arm.tensor<2x3xi32>) "None" {
410+
// CHECK: spirv.Constant dense<0> : !spirv.arm.tensor<2x3xi32>
411+
%0 = spirv.EXT.ConstantCompositeReplicate [0 : i32] : !spirv.arm.tensor<2x3xi32>
412+
spirv.ReturnValue %0 : !spirv.arm.tensor<2x3xi32>
413+
}
414+
394415
// CHECK-LABEL: @splat_vector_f32
395416
spirv.func @splat_vector_f32() -> (vector<3xf32>) "None" {
396417
// CHECK: spirv.EXT.ConstantCompositeReplicate [1.000000e+00 : f32] : vector<3xf32>
@@ -439,4 +460,11 @@ spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader, ReplicatedCompos
439460
%0 = spirv.EXT.ConstantCompositeReplicate [2.0 : f32] : !spirv.arm.tensor<2x3xf32>
440461
spirv.ReturnValue %0 : !spirv.arm.tensor<2x3xf32>
441462
}
463+
464+
// CHECK-LABEL: @null_cc_arm_tensor_of_f32
465+
spirv.func @null_cc_arm_tensor_of_f32() -> (!spirv.arm.tensor<2x3xf32>) "None" {
466+
// CHECK: spirv.Constant dense<0.000000e+00> : !spirv.arm.tensor<2x3xf32>
467+
%0 = spirv.EXT.ConstantCompositeReplicate [0.0 : f32] : !spirv.arm.tensor<2x3xf32>
468+
spirv.ReturnValue %0 : !spirv.arm.tensor<2x3xf32>
469+
}
442470
}

0 commit comments

Comments
 (0)