
Commit 0cfb1fd

spirv-opt: add support for OpTypeGraphARM to type manager (KhronosGroup#6247)
* spirv-opt: add support for OpTypeGraphARM to type manager

  Signed-off-by: Kevin Petit <[email protected]>
  Change-Id: Ibfa48eaf265bdab7b9d15c5adb9a3b8d62e5840c

* improve one test

  Change-Id: I3dedee7a790dfb89869bf8df6bb8e5b16a6f1c8d

---------

Signed-off-by: Kevin Petit <[email protected]>
1 parent bf98dd7 commit 0cfb1fd

File tree

4 files changed, +138 −3 lines changed


source/opt/type_manager.cpp

Lines changed: 30 additions & 0 deletions
@@ -525,6 +525,19 @@ uint32_t TypeManager::GetTypeInstruction(const Type* type) {
       }
       break;
     }
+    case Type::kGraphARM: {
+      auto const gty = type->AsGraphARM();
+      std::vector<Operand> ops;
+      ops.push_back(
+          Operand(SPV_OPERAND_TYPE_LITERAL_INTEGER, {gty->num_inputs()}));
+      for (auto iotype : gty->io_types()) {
+        uint32_t iotype_id = GetTypeInstruction(iotype);
+        ops.push_back(Operand(SPV_OPERAND_TYPE_ID, {iotype_id}));
+      }
+      typeInst = MakeUnique<Instruction>(context(), spv::Op::OpTypeGraphARM, 0,
+                                         id, ops);
+      break;
+    }
     default:
       assert(false && "Unexpected type");
       break;
@@ -792,6 +805,15 @@ Type* TypeManager::RebuildType(uint32_t type_id, const Type& type) {
                                          tensor_type->rank_id(), tensor_type->shape_id());
       break;
     }
+    case Type::kGraphARM: {
+      const GraphARM* graph_type = type.AsGraphARM();
+      std::vector<const Type*> io_types;
+      for (auto ioty : graph_type->io_types()) {
+        io_types.push_back(RebuildType(GetId(ioty), *ioty));
+      }
+      rebuilt_ty = MakeUnique<GraphARM>(graph_type->num_inputs(), io_types);
+      break;
+    }
     default:
       assert(false && "Unhandled type");
       return nullptr;
@@ -1091,6 +1113,14 @@ Type* TypeManager::RecordIfTypeDefinition(const Instruction& inst) {
       }
       break;
     }
+    case spv::Op::OpTypeGraphARM: {
+      std::vector<const Type*> io_types;
+      for (unsigned i = 1; i < inst.NumInOperands(); i++) {
+        io_types.push_back(GetType(inst.GetSingleWordInOperand(i)));
+      }
+      type = new GraphARM(inst.GetSingleWordInOperand(0), io_types);
+      break;
+    }
     default:
       assert(false && "Type not handled by the type manager.");
       break;
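Taken together, the three hunks above teach the type manager to emit an OpTypeGraphARM instruction from an analysis::GraphARM (GetTypeInstruction), to rebuild the type in a new context (RebuildType), and to recover it from an existing instruction (RecordIfTypeDefinition). The following is a minimal sketch, not part of the commit, of how the new GetTypeInstruction path could be exercised from client code; it assumes an already-built IRContext* (for example from spvtools::BuildModule) and uses only entry points visible in the diff and in the headers it touches.

// Hypothetical helper, not from the commit: creates (or reuses) an
// OpTypeGraphARM whose interface is two shaped f32 tensors and returns its
// result id. The rank/shape ids 4 and 44 are arbitrary values for the sketch.
#include "source/opt/ir_context.h"
#include "source/opt/type_manager.h"
#include "source/opt/types.h"

namespace opt = spvtools::opt;

uint32_t EmitGraphType(opt::IRContext* ctx) {
  opt::analysis::TypeManager* type_mgr = ctx->get_type_mgr();

  opt::analysis::Float f32(32);
  opt::analysis::TensorARM tensor(&f32, /*rank_id=*/4, /*shape_id=*/44);

  // One input; the interface is {input tensor, output tensor}.
  opt::analysis::GraphARM graph(1, {&tensor, &tensor});

  // The new Type::kGraphARM case emits the NumInputs literal followed by one
  // id operand per interface type, creating the io types first via recursive
  // GetTypeInstruction calls.
  return type_mgr->GetTypeInstruction(&graph);
}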

source/opt/types.cpp

Lines changed: 58 additions & 0 deletions
@@ -136,6 +136,7 @@ std::unique_ptr<Type> Type::Clone() const {
     DeclareKindCase(RayQueryKHR);
     DeclareKindCase(HitObjectNV);
     DeclareKindCase(TensorARM);
+    DeclareKindCase(GraphARM);
 #undef DeclareKindCase
     default:
       assert(false && "Unhandled type");
@@ -189,6 +190,7 @@ bool Type::operator==(const Type& other) const {
     DeclareKindCase(TensorLayoutNV);
     DeclareKindCase(TensorViewNV);
     DeclareKindCase(TensorARM);
+    DeclareKindCase(GraphARM);
 #undef DeclareKindCase
     default:
       assert(false && "Unhandled type");
@@ -250,6 +252,7 @@ size_t Type::ComputeHashValue(size_t hash, SeenTypes* seen) const {
     DeclareKindCase(TensorLayoutNV);
     DeclareKindCase(TensorViewNV);
     DeclareKindCase(TensorARM);
+    DeclareKindCase(GraphARM);
 #undef DeclareKindCase
     default:
       assert(false && "Unhandled type");
@@ -932,6 +935,61 @@ bool TensorARM::IsSameImpl(const Type* that, IsSameCache* seen) const {
          HasSameDecorations(that);
 }

+GraphARM::GraphARM(const uint32_t num_inputs,
+                   const std::vector<const Type*>& io_types)
+    : Type(kGraphARM), num_inputs_(num_inputs), io_types_(io_types) {
+  assert(io_types.size() > 0);
+}
+
+std::string GraphARM::str() const {
+  std::ostringstream oss;
+  oss << "graph<" << num_inputs_;
+  for (auto ioty : io_types_) {
+    oss << "," << ioty->str();
+  }
+  oss << ">";
+  return oss.str();
+}
+
+bool GraphARM::is_shaped() const {
+  // A graph is considered to be shaped if all its interface tensors are shaped
+  for (auto ioty : io_types_) {
+    auto tensor_type = ioty->AsTensorARM();
+    assert(tensor_type);
+    if (!tensor_type->is_shaped()) {
+      return false;
+    }
+  }
+  return true;
+}
+
+size_t GraphARM::ComputeExtraStateHash(size_t hash, SeenTypes* seen) const {
+  hash = hash_combine(hash, num_inputs_);
+  for (auto ioty : io_types_) {
+    hash = ioty->ComputeHashValue(hash, seen);
+  }
+  return hash;
+}
+
+bool GraphARM::IsSameImpl(const Type* that, IsSameCache* seen) const {
+  const GraphARM* og = that->AsGraphARM();
+  if (!og) {
+    return false;
+  }
+  if (num_inputs_ != og->num_inputs_) {
+    return false;
+  }
+  if (io_types_.size() != og->io_types_.size()) {
+    return false;
+  }
+  for (size_t i = 0; i < io_types_.size(); i++) {
+    if (!io_types_[i]->IsSameImpl(og->io_types_[i], seen)) {
+      return false;
+    }
+  }
+  return true;
+}
+
 }  // namespace analysis
 }  // namespace opt
 }  // namespace spvtools
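For reference, here is a rough illustration, not part of the commit, of what the new members do, built on the same stack-allocated analysis types as the earlier sketch; the exact str() output follows the tensor<...> format checked by the updated tests.

// Illustrative only, not from the commit: expected behaviour of
// GraphARM::str() and GraphARM::is_shaped(), based on the definitions above.
// The rank/shape ids 4 and 44 are arbitrary values chosen for the sketch.
#include "source/opt/types.h"

void GraphTypeBehaviourSketch() {
  using namespace spvtools::opt::analysis;

  Float f32(32);
  TensorARM shaped(&f32, /*rank_id=*/4, /*shape_id=*/44);  // carries a shape id
  TensorARM ranked_only(&f32, /*rank_id=*/4);              // no shape id

  GraphARM g1(1, {&shaped, &shaped});
  // g1.str() yields something like
  //   "graph<1,tensor<float32, id(4), id(44)>,tensor<float32, id(4), id(44)>>"
  // g1.is_shaped() is true: every interface tensor carries a shape id.

  GraphARM g2(1, {&ranked_only, &shaped});
  // g2.is_shaped() is false: the first interface tensor has no shape id.
}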

source/opt/types.h

Lines changed: 28 additions & 0 deletions
@@ -70,6 +70,7 @@ class HitObjectNV;
 class TensorLayoutNV;
 class TensorViewNV;
 class TensorARM;
+class GraphARM;

 // Abstract class for a SPIR-V type. It has a bunch of As<sublcass>() methods,
 // which is used as a way to probe the actual <subclass>.
@@ -116,6 +117,7 @@ class Type {
     kTensorLayoutNV,
     kTensorViewNV,
     kTensorARM,
+    kGraphARM,
     kLast
   };

@@ -223,6 +225,7 @@ class Type {
   DeclareCastMethod(TensorLayoutNV)
   DeclareCastMethod(TensorViewNV)
   DeclareCastMethod(TensorARM)
+  DeclareCastMethod(GraphARM)
 #undef DeclareCastMethod

  protected:
@@ -793,6 +796,8 @@ class TensorARM : public Type {
   const Type* element_type() const { return element_type_; }
   uint32_t rank_id() const { return rank_id_; }
   uint32_t shape_id() const { return shape_id_; }
+  bool is_ranked() const { return rank_id_ != 0; }
+  bool is_shaped() const { return shape_id_ != 0; }

  private:
   bool IsSameImpl(const Type* that, IsSameCache*) const override;
@@ -802,6 +807,29 @@ class TensorARM : public Type {
   const uint32_t shape_id_;
 };

+class GraphARM : public Type {
+ public:
+  GraphARM(const uint32_t num_inputs, const std::vector<const Type*>& io_types);
+  GraphARM(const GraphARM&) = default;
+
+  std::string str() const override;
+
+  GraphARM* AsGraphARM() override { return this; }
+  const GraphARM* AsGraphARM() const override { return this; }
+
+  uint32_t num_inputs() const { return num_inputs_; }
+  const std::vector<const Type*>& io_types() const { return io_types_; }
+  bool is_shaped() const;
+
+  size_t ComputeExtraStateHash(size_t hash, SeenTypes* seen) const override;
+
+ private:
+  bool IsSameImpl(const Type* that, IsSameCache*) const override;
+
+  const uint32_t num_inputs_;
+  const std::vector<const Type*> io_types_;
+};
+
 #define DefineParameterlessType(type, name) \
   class type : public Type {                \
    public:                                  \
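A short sketch, not in the commit, of how client code could probe for the new kind through the AsGraphARM() cast declared above; it assumes (consistent with the tests below, where NumInputs precedes the interface types) that io_types() lists the graph's inputs first, followed by its outputs.

// Hypothetical helper: returns the number of output types of a graph type,
// or 0 if "ty" is not a GraphARM. Uses only the accessors declared above and
// assumes io_types() holds the inputs followed by the outputs.
#include <cstdint>
#include "source/opt/types.h"

uint32_t NumGraphOutputs(const spvtools::opt::analysis::Type* ty) {
  const spvtools::opt::analysis::GraphARM* graph = ty->AsGraphARM();
  if (graph == nullptr) return 0;  // Not an OpTypeGraphARM type.
  return static_cast<uint32_t>(graph->io_types().size()) - graph->num_inputs();
}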

test/opt/type_manager_test.cpp

Lines changed: 22 additions & 3 deletions
@@ -184,8 +184,16 @@ std::vector<std::unique_ptr<Type>> GenerateAllTypes() {

   // Tensors
   types.emplace_back(new TensorARM(f32));
+  auto* tensor_f32 = types.back().get();
   types.emplace_back(new TensorARM(f32, 4));
+  auto* tensor_f32_ranked = types.back().get();
   types.emplace_back(new TensorARM(f32, 4, 44));
+  auto* tensor_f32_shaped = types.back().get();
+
+  // Graph
+  types.emplace_back(new GraphARM(0, {tensor_f32}));
+  types.emplace_back(new GraphARM(1, {tensor_f32_ranked, tensor_f32_ranked}));
+  types.emplace_back(new GraphARM(1, {tensor_f32_shaped, tensor_f32_shaped}));

   types.emplace_back(new TensorLayoutNV(1002, 1000));
   types.emplace_back(new TensorViewNV(1002, 1003, {1000, 1001}));
@@ -261,6 +269,9 @@ TEST(TypeManager, TypeStrings) {
     %ts = OpTypeTensorARM %u32
     %tsr = OpTypeTensorARM %u32 %id4
     %tss = OpTypeTensorARM %u32 %id4 %ts_shape
+    %g_noin = OpTypeGraphARM 0 %ts
+    %g_onein = OpTypeGraphARM 1 %tsr %tsr
+    %g_shaped = OpTypeGraphARM 1 %tss %tss
   )";

   std::vector<std::pair<uint32_t, std::string>> type_id_strs = {
@@ -305,6 +316,11 @@ TEST(TypeManager, TypeStrings) {
       {44, "tensor<uint32, id(0), id(0)>"},
       {45, "tensor<uint32, id(6), id(0)>"},
       {46, "tensor<uint32, id(6), id(43)>"},
+      {47, "graph<0,tensor<uint32, id(0), id(0)>>"},
+      {48,
+       "graph<1,tensor<uint32, id(6), id(0)>,tensor<uint32, id(6), id(0)>>"},
+      {49,
+       "graph<1,tensor<uint32, id(6), id(43)>,tensor<uint32, id(6), id(43)>>"},
   };

   std::unique_ptr<IRContext> context =
@@ -1124,9 +1140,12 @@ TEST(TypeManager, GetTypeInstructionAllTypes) {
 ; CHECK: OpTypeCooperativeMatrixKHR [[f32]] [[uint8]] [[uint8]] [[uint8]] [[uint2]]
 ; CHECK: OpTypeRayQueryKHR
 ; CHECK: OpTypeHitObjectNV
-; CHECK: OpTypeTensorARM [[f32]]
-; CHECK: OpTypeTensorARM [[f32]] [[uint4]]
-; CHECK: OpTypeTensorARM [[f32]] [[uint4]] [[uint_arr4_44]]
+; CHECK: [[tensor_f32:%\w+]] = OpTypeTensorARM [[f32]]
+; CHECK: [[tensor_f32_ranked:%\w+]] = OpTypeTensorARM [[f32]] [[uint4]]
+; CHECK: [[tensor_f32_shaped:%\w+]] = OpTypeTensorARM [[f32]] [[uint4]] [[uint_arr4_44]]
+; CHECK: OpTypeGraphARM 0 [[tensor_f32]]
+; CHECK: OpTypeGraphARM 1 [[tensor_f32_ranked]] [[tensor_f32_ranked]]
+; CHECK: OpTypeGraphARM 1 [[tensor_f32_shaped]] [[tensor_f32_shaped]]
 OpCapability Shader
 OpCapability Int64
 OpCapability Linkage
