diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp index 52c672a05fa43..f99339852824c 100644 --- a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp +++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp @@ -767,19 +767,25 @@ void mlir::spirv::AddressOfOp::getAsmResultNames( // spirv.EXTConstantCompositeReplicate //===----------------------------------------------------------------------===// +// Returns type of attribute. In case of a TypedAttr this will simply return +// the type. But for an ArrayAttr which is untyped and can be multidimensional +// it creates the ArrayType recursively. +static Type getValueType(Attribute attr) { + if (auto typedAttr = dyn_cast(attr)) { + return typedAttr.getType(); + } + + if (auto arrayAttr = dyn_cast(attr)) { + return spirv::ArrayType::get(getValueType(arrayAttr[0]), arrayAttr.size()); + } + + return nullptr; +} + LogicalResult spirv::EXTConstantCompositeReplicateOp::verify() { - Type valueType; - if (auto typedAttr = dyn_cast(getValue())) { - valueType = typedAttr.getType(); - } else if (auto arrayAttr = dyn_cast(getValue())) { - auto typedElemAttr = dyn_cast(arrayAttr[0]); - if (!typedElemAttr) - return emitError("value attribute is not typed"); - valueType = - spirv::ArrayType::get(typedElemAttr.getType(), arrayAttr.size()); - } else { + Type valueType = getValueType(getValue()); + if (!valueType) return emitError("unknown value attribute type"); - } auto compositeType = dyn_cast(getType()); if (!compositeType) diff --git a/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp b/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp index 59665ec1add54..30536638b56f7 100644 --- a/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp +++ b/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp @@ -1187,6 +1187,21 @@ uint32_t Serializer::prepareConstantFp(Location loc, FloatAttr floatAttr, return resultID; } +// Returns type of attribute. In case of a TypedAttr this will simply return +// the type. But for an ArrayAttr which is untyped and can be multidimensional +// it creates the ArrayType recursively. +static Type getValueType(Attribute attr) { + if (auto typedAttr = dyn_cast(attr)) { + return typedAttr.getType(); + } + + if (auto arrayAttr = dyn_cast(attr)) { + return spirv::ArrayType::get(getValueType(arrayAttr[0]), arrayAttr.size()); + } + + return nullptr; +} + uint32_t Serializer::prepareConstantCompositeReplicate(Location loc, Type resultType, Attribute valueAttr) { @@ -1200,18 +1215,9 @@ uint32_t Serializer::prepareConstantCompositeReplicate(Location loc, return 0; } - Type valueType; - if (auto typedAttr = dyn_cast(valueAttr)) { - valueType = typedAttr.getType(); - } else if (auto arrayAttr = dyn_cast(valueAttr)) { - auto typedElemAttr = dyn_cast(arrayAttr[0]); - if (!typedElemAttr) - return 0; - valueType = - spirv::ArrayType::get(typedElemAttr.getType(), arrayAttr.size()); - } else { + Type valueType = getValueType(valueAttr); + if (!valueAttr) return 0; - } auto compositeType = dyn_cast(resultType); if (!compositeType) diff --git a/mlir/test/Target/SPIRV/constant.mlir b/mlir/test/Target/SPIRV/constant.mlir index 3be49eefcaebf..c81ceac072bd0 100644 --- a/mlir/test/Target/SPIRV/constant.mlir +++ b/mlir/test/Target/SPIRV/constant.mlir @@ -405,6 +405,13 @@ spirv.module Logical GLSL450 requires #spirv.vce } + // CHECK-LABEL: @splat_array_of_non_splat_array_of_arrays_of_i32 + spirv.func @splat_array_of_non_splat_array_of_arrays_of_i32() -> !spirv.array<2 x !spirv.array<2 x !spirv.array<3 x i32>>> "None" { + // CHECK: spirv.EXT.ConstantCompositeReplicate {{\[}}{{\[}}[1 : i32, 2 : i32, 3 : i32], [4 : i32, 5 : i32, 6 : i32]]] : !spirv.array<2 x !spirv.array<2 x !spirv.array<3 x i32>>> + %0 = spirv.EXT.ConstantCompositeReplicate [[[1 : i32, 2 : i32, 3 : i32], [4 : i32, 5 : i32, 6 : i32]]] : !spirv.array<2 x !spirv.array<2 x !spirv.array<3 x i32>>> + spirv.ReturnValue %0 : !spirv.array<2 x !spirv.array<2 x !spirv.array<3 x i32>>> + } + // CHECK-LABEL: @null_cc_arm_tensor_of_i32 spirv.func @null_cc_arm_tensor_of_i32() -> (!spirv.arm.tensor<2x3xi32>) "None" { // CHECK: spirv.Constant dense<0> : !spirv.arm.tensor<2x3xi32> @@ -461,6 +468,13 @@ spirv.module Logical GLSL450 requires #spirv.vce } + // CHECK-LABEL: @splat_array_of_non_splat_array_of_arrays_of_f32 + spirv.func @splat_array_of_non_splat_array_of_arrays_of_f32() -> !spirv.array<2 x !spirv.array<2 x !spirv.array<3 x f32>>> "None" { + // CHECK: spirv.EXT.ConstantCompositeReplicate {{\[}}{{\[}}[1.000000e+00 : f32, 2.000000e+00 : f32, 3.000000e+00 : f32], [4.000000e+00 : f32, 5.000000e+00 : f32, 6.000000e+00 : f32]]] : !spirv.array<2 x !spirv.array<2 x !spirv.array<3 x f32>>> + %0 = spirv.EXT.ConstantCompositeReplicate [[[1.0 : f32, 2.0 : f32, 3.0 : f32], [4.0 : f32, 5.0 : f32, 6.0 : f32]]] : !spirv.array<2 x !spirv.array<2 x !spirv.array<3 x f32>>> + spirv.ReturnValue %0 : !spirv.array<2 x !spirv.array<2 x !spirv.array<3 x f32>>> + } + // CHECK-LABEL: @null_cc_arm_tensor_of_f32 spirv.func @null_cc_arm_tensor_of_f32() -> (!spirv.arm.tensor<2x3xf32>) "None" { // CHECK: spirv.Constant dense<0.000000e+00> : !spirv.arm.tensor<2x3xf32>