diff --git a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp index c967e863554fc..d8c54ec5f88c3 100644 --- a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp +++ b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp @@ -1560,7 +1560,19 @@ spirv::Deserializer::processConstantComposite(ArrayRef operands) { } auto resultID = operands[1]; - if (auto shapedType = dyn_cast(resultType)) { + if (auto tensorType = dyn_cast(resultType)) { + SmallVector flattenedElems; + for (Attribute element : elements) { + if (auto denseElemAttr = dyn_cast(element)) { + for (auto value : denseElemAttr.getValues()) + flattenedElems.push_back(value); + } else { + flattenedElems.push_back(element); + } + } + auto attr = DenseElementsAttr::get(tensorType, flattenedElems); + constantMap.try_emplace(resultID, attr, tensorType); + } else if (auto shapedType = dyn_cast(resultType)) { auto attr = DenseElementsAttr::get(shapedType, elements); // For normal constants, we just record the attribute (and its type) for // later materialization at use sites. diff --git a/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp b/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp index c049574fbc9e3..7c007de315589 100644 --- a/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp +++ b/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp @@ -956,6 +956,11 @@ Serializer::prepareDenseElementsConstant(Location loc, Type constType, uint32_t resultID = getNextID(); SmallVector operands = {typeID, resultID}; auto elementType = cast(constType).getElementType(0); + if (auto tensorArmType = dyn_cast(constType)) { + ArrayRef innerShape = tensorArmType.getShape().drop_front(); + if (!innerShape.empty()) + elementType = spirv::TensorArmType::get(innerShape, elementType); + } // "If the Result Type is a cooperative matrix type, then there must be only // one Constituent, with scalar type matching the cooperative matrix Component @@ -979,30 +984,10 @@ Serializer::prepareDenseElementsConstant(Location loc, Type constType, } else { return 0; } - } else if (isa(constType)) { - if (isZeroValue(valueAttr)) { - encodeInstructionInto(typesGlobalValues, spirv::Opcode::OpConstantNull, - {typeID, resultID}); - return resultID; - } - numberOfConstituents = shapedType.getNumElements(); - operands.reserve(numberOfConstituents + 2); - for (int i = 0; i < numberOfConstituents; ++i) { - uint32_t elementID = 0; - if (auto attr = dyn_cast(valueAttr)) { - elementID = - elementType.isInteger(1) - ? prepareConstantBool(loc, attr.getValues()[i]) - : prepareConstantInt(loc, attr.getValues()[i]); - } - if (auto attr = dyn_cast(valueAttr)) { - elementID = prepareConstantFp(loc, attr.getValues()[i]); - } - if (!elementID) { - return 0; - } - operands.push_back(elementID); - } + } else if (isa(constType) && isZeroValue(valueAttr)) { + encodeInstructionInto(typesGlobalValues, spirv::Opcode::OpConstantNull, + {typeID, resultID}); + return resultID; } else { operands.reserve(numberOfConstituents + 2); for (int i = 0; i < numberOfConstituents; ++i) { diff --git a/mlir/test/Target/SPIRV/arm-tensor-constant.mlir b/mlir/test/Target/SPIRV/arm-tensor-constant.mlir index 275e586f70634..7fb8af1904388 100644 --- a/mlir/test/Target/SPIRV/arm-tensor-constant.mlir +++ b/mlir/test/Target/SPIRV/arm-tensor-constant.mlir @@ -1,17 +1,36 @@ // RUN: mlir-translate --no-implicit-module --test-spirv-roundtrip %s | FileCheck %s -// DISABLED: %if spirv-tools %{ mlir-translate --no-implicit-module --serialize-spirv %s | spirv-val %} - -// FIXME(#152012): Fix arm tensor constant validation errors and reenable spirv-val tests. +// RUN: %if spirv-tools %{ mlir-translate --no-implicit-module --serialize-spirv %s | spirv-val %} spirv.module Logical Vulkan requires #spirv.vce { - // CHECK-LABEL: @arm_tensor_of_i32 - spirv.func @arm_tensor_of_i32() -> (!spirv.arm.tensor<2x3xi32>) "None" { + // CHECK-LABEL: @rank_1_arm_tensor_of_i32 + spirv.func @rank_1_arm_tensor_of_i32() -> (!spirv.arm.tensor<3xi32>) "None" { + // CHECK: {{%.*}} = spirv.Constant dense<[1, 2, 3]> : !spirv.arm.tensor<3xi32> + %0 = spirv.Constant dense<[1, 2, 3]> : !spirv.arm.tensor<3xi32> + spirv.ReturnValue %0 : !spirv.arm.tensor<3xi32> + } + + // CHECK-LABEL: @rank_2_arm_tensor_of_i32 + spirv.func @rank_2_arm_tensor_of_i32() -> (!spirv.arm.tensor<2x3xi32>) "None" { // CHECK: {{%.*}} = spirv.Constant dense<{{\[}}[1, 2, 3], [4, 5, 6]]> : !spirv.arm.tensor<2x3xi32> %0 = spirv.Constant dense<[[1, 2, 3], [4, 5, 6]]> : !spirv.arm.tensor<2x3xi32> spirv.ReturnValue %0 : !spirv.arm.tensor<2x3xi32> } + // CHECK-LABEL: @rank_3_arm_tensor_of_i32 + spirv.func @rank_3_arm_tensor_of_i32() -> (!spirv.arm.tensor<2x2x3xi32>) "None" { + // CHECK: {{%.*}} = spirv.Constant dense<{{\[}}{{\[}}[1, 2, 3], [4, 5, 6]], {{\[}}[7, 8, 9], [10, 11, 12]]]> : !spirv.arm.tensor<2x2x3xi32> + %0 = spirv.Constant dense<[[[1, 2, 3], [4, 5, 6]], [[7, 8, 9], [10, 11, 12]]]> : !spirv.arm.tensor<2x2x3xi32> + spirv.ReturnValue %0 : !spirv.arm.tensor<2x2x3xi32> + } + + // CHECK-LABEL: @rank_4_arm_tensor_of_i32 + spirv.func @rank_4_arm_tensor_of_i32() -> (!spirv.arm.tensor<2x3x4x5xi32>) "None" { + // CHECK: {{%.*}} = spirv.Constant dense<5> : !spirv.arm.tensor<2x3x4x5xi32> + %0 = spirv.Constant dense<5> : !spirv.arm.tensor<2x3x4x5xi32> + spirv.ReturnValue %0 : !spirv.arm.tensor<2x3x4x5xi32> + } + // CHECK-LABEL: @splat_arm_tensor_of_i32 spirv.func @splat_arm_tensor_of_i32() -> (!spirv.arm.tensor<2x3xi32>) "None" { // CHECK: {{%.*}} = spirv.Constant dense<2> : !spirv.arm.tensor<2x3xi32> @@ -19,13 +38,34 @@ spirv.module Logical Vulkan requires #spirv.vce } - // CHECK-LABEL: @arm_tensor_of_f32 - spirv.func @arm_tensor_of_f32() -> (!spirv.arm.tensor<2x3xf32>) "None" { + // CHECK-LABEL: @rank_1_arm_tensor_of_f32 + spirv.func @rank_1_arm_tensor_of_f32() -> (!spirv.arm.tensor<3xf32>) "None" { + // CHECK: {{%.*}} = spirv.Constant dense<[1.000000e+00, 2.000000e+00, 3.000000e+00]> : !spirv.arm.tensor<3xf32> + %0 = spirv.Constant dense<[1.0, 2.0, 3.0]> : !spirv.arm.tensor<3xf32> + spirv.ReturnValue %0 : !spirv.arm.tensor<3xf32> + } + + // CHECK-LABEL: @rank_2_arm_tensor_of_f32 + spirv.func @rank_2_arm_tensor_of_f32() -> (!spirv.arm.tensor<2x3xf32>) "None" { // CHECK: {{%.*}} = spirv.Constant dense<{{\[}}[1.000000e+00, 2.000000e+00, 3.000000e+00], [4.000000e+00, 5.000000e+00, 6.000000e+00]]> : !spirv.arm.tensor<2x3xf32> - %0 = spirv.Constant dense<[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]>: !spirv.arm.tensor<2x3xf32> + %0 = spirv.Constant dense<[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]> : !spirv.arm.tensor<2x3xf32> spirv.ReturnValue %0 : !spirv.arm.tensor<2x3xf32> } + // CHECK-LABEL: @rank_3_arm_tensor_of_f32 + spirv.func @rank_3_arm_tensor_of_f32() -> (!spirv.arm.tensor<2x2x3xf32>) "None" { + // CHECK: {{%.*}} = spirv.Constant dense<{{\[}}{{\[}}[1.000000e+00, 2.000000e+00, 3.000000e+00], [4.000000e+00, 5.000000e+00, 6.000000e+00]], {{\[}}[7.000000e+00, 8.000000e+00, 9.000000e+00], [1.000000e+01, 1.100000e+01, 1.200000e+01]]]> : !spirv.arm.tensor<2x2x3xf32> + %0 = spirv.Constant dense<[[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], [[7.0, 8.0, 9.0], [10.0, 11.0, 12.0]]]> : !spirv.arm.tensor<2x2x3xf32> + spirv.ReturnValue %0 : !spirv.arm.tensor<2x2x3xf32> + } + + // CHECK-LABEL: @rank_4_arm_tensor_of_f32 + spirv.func @rank_4_arm_tensor_of_f32() -> (!spirv.arm.tensor<2x3x4x5xf32>) "None" { + // CHECK: {{%.*}} = spirv.Constant dense<5.000000e+00> : !spirv.arm.tensor<2x3x4x5xf32> + %0 = spirv.Constant dense<5.0> : !spirv.arm.tensor<2x3x4x5xf32> + spirv.ReturnValue %0 : !spirv.arm.tensor<2x3x4x5xf32> + } + // CHECK-LABEL: @splat_arm_tensor_of_f32 spirv.func @splat_arm_tensor_of_f32() -> (!spirv.arm.tensor<2x3xf32>) "None" { // CHECK: {{%.*}} = spirv.Constant dense<2.000000e+00> : !spirv.arm.tensor<2x3xf32>