Skip to content

Commit d6ef579

Browse files
authored
spirv-opt: add support for tensors to type manager (KhronosGroup#6202)
As specified by SPV_ARM_tensors. Change-Id: I229ea1a3232eb04a4124363d1641218acb298ac4 Signed-off-by: Kevin Petit <[email protected]>
1 parent e16fcd1 commit d6ef579

File tree

4 files changed

+139
-0
lines changed

4 files changed

+139
-0
lines changed

source/opt/type_manager.cpp

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -495,6 +495,36 @@ uint32_t TypeManager::GetTypeInstruction(const Type* type) {
495495
{SPV_OPERAND_TYPE_ID, {coop_vec->components()}}});
496496
break;
497497
}
498+
case Type::kTensorARM: {
499+
auto tensor_type = type->AsTensorARM();
500+
uint32_t const element_type =
501+
GetTypeInstruction(tensor_type->element_type());
502+
if (element_type == 0) {
503+
return 0;
504+
}
505+
if (tensor_type->rank_id() != 0) {
506+
if (tensor_type->shape_id() != 0) {
507+
typeInst = MakeUnique<Instruction>(
508+
context(), spv::Op::OpTypeTensorARM, 0, id,
509+
std::initializer_list<Operand>{
510+
{SPV_OPERAND_TYPE_ID, {element_type}},
511+
{SPV_OPERAND_TYPE_ID, {tensor_type->rank_id()}},
512+
{SPV_OPERAND_TYPE_ID, {tensor_type->shape_id()}}});
513+
} else {
514+
typeInst = MakeUnique<Instruction>(
515+
context(), spv::Op::OpTypeTensorARM, 0, id,
516+
std::initializer_list<Operand>{
517+
{SPV_OPERAND_TYPE_ID, {element_type}},
518+
{SPV_OPERAND_TYPE_ID, {tensor_type->rank_id()}}});
519+
}
520+
} else {
521+
typeInst =
522+
MakeUnique<Instruction>(context(), spv::Op::OpTypeTensorARM, 0, id,
523+
std::initializer_list<Operand>{
524+
{SPV_OPERAND_TYPE_ID, {element_type}}});
525+
}
526+
break;
527+
}
498528
default:
499529
assert(false && "Unexpected type");
500530
break;
@@ -754,6 +784,14 @@ Type* TypeManager::RebuildType(uint32_t type_id, const Type& type) {
754784
cv_type->components());
755785
break;
756786
}
787+
case Type::kTensorARM: {
788+
const TensorARM* tensor_type = type.AsTensorARM();
789+
const Type* element_type = tensor_type->element_type();
790+
rebuilt_ty = MakeUnique<TensorARM>(
791+
RebuildType(GetId(element_type), *element_type),
792+
tensor_type->rank_id(), tensor_type->shape_id());
793+
break;
794+
}
757795
default:
758796
assert(false && "Unhandled type");
759797
return nullptr;
@@ -1036,6 +1074,23 @@ Type* TypeManager::RecordIfTypeDefinition(const Instruction& inst) {
10361074
inst.GetSingleWordInOperand(1), perm);
10371075
break;
10381076
}
1077+
case spv::Op::OpTypeTensorARM: {
1078+
switch (inst.NumInOperands()) {
1079+
case 1:
1080+
type = new TensorARM(GetType(inst.GetSingleWordInOperand(0)));
1081+
break;
1082+
case 2:
1083+
type = new TensorARM(GetType(inst.GetSingleWordInOperand(0)),
1084+
inst.GetSingleWordInOperand(1));
1085+
break;
1086+
case 3:
1087+
type = new TensorARM(GetType(inst.GetSingleWordInOperand(0)),
1088+
inst.GetSingleWordInOperand(1),
1089+
inst.GetSingleWordInOperand(2));
1090+
break;
1091+
}
1092+
break;
1093+
}
10391094
default:
10401095
assert(false && "Type not handled by the type manager.");
10411096
break;

source/opt/types.cpp

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -135,6 +135,7 @@ std::unique_ptr<Type> Type::Clone() const {
135135
DeclareKindCase(CooperativeVectorNV);
136136
DeclareKindCase(RayQueryKHR);
137137
DeclareKindCase(HitObjectNV);
138+
DeclareKindCase(TensorARM);
138139
#undef DeclareKindCase
139140
default:
140141
assert(false && "Unhandled type");
@@ -187,6 +188,7 @@ bool Type::operator==(const Type& other) const {
187188
DeclareKindCase(HitObjectNV);
188189
DeclareKindCase(TensorLayoutNV);
189190
DeclareKindCase(TensorViewNV);
191+
DeclareKindCase(TensorARM);
190192
#undef DeclareKindCase
191193
default:
192194
assert(false && "Unhandled type");
@@ -247,6 +249,7 @@ size_t Type::ComputeHashValue(size_t hash, SeenTypes* seen) const {
247249
DeclareKindCase(HitObjectNV);
248250
DeclareKindCase(TensorLayoutNV);
249251
DeclareKindCase(TensorViewNV);
252+
DeclareKindCase(TensorARM);
250253
#undef DeclareKindCase
251254
default:
252255
assert(false && "Unhandled type");
@@ -899,6 +902,36 @@ bool CooperativeVectorNV::IsSameImpl(const Type* that,
899902
components_ == mt->components_ && HasSameDecorations(that);
900903
}
901904

905+
TensorARM::TensorARM(const Type* elty, const uint32_t rank,
906+
const uint32_t shape)
907+
: Type(kTensorARM), element_type_(elty), rank_id_(rank), shape_id_(shape) {
908+
assert(elty != nullptr);
909+
if (shape != 0) {
910+
assert(rank != 0);
911+
}
912+
}
913+
914+
std::string TensorARM::str() const {
915+
std::ostringstream oss;
916+
oss << "tensor<" << element_type_->str() << ", id(" << rank_id_ << "), id("
917+
<< shape_id_ << ")>";
918+
return oss.str();
919+
}
920+
921+
size_t TensorARM::ComputeExtraStateHash(size_t hash, SeenTypes* seen) const {
922+
hash = hash_combine(hash, rank_id_);
923+
hash = hash_combine(hash, shape_id_);
924+
return element_type_->ComputeHashValue(hash, seen);
925+
}
926+
927+
bool TensorARM::IsSameImpl(const Type* that, IsSameCache* seen) const {
928+
const TensorARM* tt = that->AsTensorARM();
929+
if (!tt) return false;
930+
return element_type_->IsSameImpl(tt->element_type_, seen) &&
931+
rank_id_ == tt->rank_id_ && shape_id_ == tt->shape_id_ &&
932+
HasSameDecorations(that);
933+
}
934+
902935
} // namespace analysis
903936
} // namespace opt
904937
} // namespace spvtools

source/opt/types.h

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,7 @@ class RayQueryKHR;
6969
class HitObjectNV;
7070
class TensorLayoutNV;
7171
class TensorViewNV;
72+
class TensorARM;
7273

7374
// Abstract class for a SPIR-V type. It has a bunch of As<sublcass>() methods,
7475
// which is used as a way to probe the actual <subclass>.
@@ -114,6 +115,7 @@ class Type {
114115
kHitObjectNV,
115116
kTensorLayoutNV,
116117
kTensorViewNV,
118+
kTensorARM,
117119
kLast
118120
};
119121

@@ -220,6 +222,7 @@ class Type {
220222
DeclareCastMethod(HitObjectNV)
221223
DeclareCastMethod(TensorLayoutNV)
222224
DeclareCastMethod(TensorViewNV)
225+
DeclareCastMethod(TensorARM)
223226
#undef DeclareCastMethod
224227

225228
protected:
@@ -774,6 +777,31 @@ class CooperativeVectorNV : public Type {
774777
const uint32_t components_;
775778
};
776779

780+
class TensorARM : public Type {
781+
public:
782+
TensorARM(const Type* elty, const uint32_t rank = 0,
783+
const uint32_t shape = 0);
784+
TensorARM(const TensorARM&) = default;
785+
786+
std::string str() const override;
787+
788+
TensorARM* AsTensorARM() override { return this; }
789+
const TensorARM* AsTensorARM() const override { return this; }
790+
791+
size_t ComputeExtraStateHash(size_t hash, SeenTypes* seen) const override;
792+
793+
const Type* element_type() const { return element_type_; }
794+
uint32_t rank_id() const { return rank_id_; }
795+
uint32_t shape_id() const { return shape_id_; }
796+
797+
private:
798+
bool IsSameImpl(const Type* that, IsSameCache*) const override;
799+
800+
const Type* element_type_;
801+
const uint32_t rank_id_;
802+
const uint32_t shape_id_;
803+
};
804+
777805
#define DefineParameterlessType(type, name) \
778806
class type : public Type { \
779807
public: \

test/opt/type_manager_test.cpp

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -182,6 +182,11 @@ std::vector<std::unique_ptr<Type>> GenerateAllTypes() {
182182
// SPV_AMDX_shader_enqueue
183183
types.emplace_back(new NodePayloadArrayAMDX(sts32f32));
184184

185+
// Tensors
186+
types.emplace_back(new TensorARM(f32));
187+
types.emplace_back(new TensorARM(f32, 4));
188+
types.emplace_back(new TensorARM(f32, 4, 44));
189+
185190
types.emplace_back(new TensorLayoutNV(1002, 1000));
186191
types.emplace_back(new TensorViewNV(1002, 1003, {1000, 1001}));
187192

@@ -251,6 +256,11 @@ TEST(TypeManager, TypeStrings) {
251256
%id2 = OpConstant %u32 2
252257
%cmkhr = OpTypeCooperativeMatrixKHR %f64 %id4 %id4 %id4 %id2
253258
%untyped = OpTypeUntypedPointerKHR Uniform
259+
; ID 43
260+
%ts_shape = OpConstantComposite %a5u32 %id4 %id4 %id4 %id4
261+
%ts = OpTypeTensorARM %u32
262+
%tsr = OpTypeTensorARM %u32 %id4
263+
%tss = OpTypeTensorARM %u32 %id4 %ts_shape
254264
)";
255265

256266
std::vector<std::pair<uint32_t, std::string>> type_id_strs = {
@@ -291,6 +301,10 @@ TEST(TypeManager, TypeStrings) {
291301
{39, "<float64, 6, 6, 6>"},
292302
{41, "<float64, 6, 6, 6, 40>"},
293303
{42, "untyped_ptr 2*"}, // Include storage class number
304+
// Id 43 is OpConstantComposite %a5u32 %id4 %id4 %id4 %id4
305+
{44, "tensor<uint32, id(0), id(0)>"},
306+
{45, "tensor<uint32, id(6), id(0)>"},
307+
{46, "tensor<uint32, id(6), id(43)>"},
294308
};
295309

296310
std::unique_ptr<IRContext> context =
@@ -1049,8 +1063,11 @@ TEST(TypeManager, GetTypeInstructionAllTypes) {
10491063
; CHECK: [[uniform_ptr:%\w+]] = OpTypePointer Uniform [[uint]]
10501064
; CHECK: [[uint2:%\w+]] = OpConstant [[uint]] 2
10511065
; CHECK: [[uint8:%\w+]] = OpConstant [[uint]] 8
1066+
; CHECK: [[uint4:%\w+]] = OpConstant [[uint]] 4
1067+
; CHECK: [[uint_arr4:%\w+]] = OpTypeArray [[uint]] [[uint4]]
10521068
; CHECK: [[uint24:%\w+]] = OpConstant [[uint]] 24
10531069
; CHECK: [[uint42:%\w+]] = OpConstant [[uint]] 42
1070+
; CHECK: [[uint_arr4_44:%\w+]] = OpConstantComposite [[uint_arr4]] [[uint4]] [[uint4]] [[uint4]] [[uint4]]
10541071
; CHECK: [[uint100:%\w+]] = OpConstant [[uint]] 100
10551072
; CHECK: [[void:%\w+]] = OpTypeVoid
10561073
; CHECK: [[bool:%\w+]] = OpTypeBool
@@ -1107,6 +1124,9 @@ TEST(TypeManager, GetTypeInstructionAllTypes) {
11071124
; CHECK: OpTypeCooperativeMatrixKHR [[f32]] [[uint8]] [[uint8]] [[uint8]] [[uint2]]
11081125
; CHECK: OpTypeRayQueryKHR
11091126
; CHECK: OpTypeHitObjectNV
1127+
; CHECK: OpTypeTensorARM [[f32]]
1128+
; CHECK: OpTypeTensorARM [[f32]] [[uint4]]
1129+
; CHECK: OpTypeTensorARM [[f32]] [[uint4]] [[uint_arr4_44]]
11101130
OpCapability Shader
11111131
OpCapability Int64
11121132
OpCapability Linkage
@@ -1118,8 +1138,11 @@ OpMemoryModel Logical GLSL450
11181138
%1001 = OpConstant %uint 1
11191139
%1002 = OpConstant %uint 2
11201140
%8 = OpConstant %uint 8
1141+
%4 = OpConstant %uint 4
1142+
%5 = OpTypeArray %uint %4
11211143
%24 = OpConstant %uint 24
11221144
%42 = OpConstant %uint 42
1145+
%44 = OpConstantComposite %5 %4 %4 %4 %4
11231146
%100 = OpConstant %uint 100
11241147
%1003 = OpConstantFalse %bool
11251148
)";

0 commit comments

Comments
 (0)