Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ static_assert(FLATBUFFERS_VERSION_MAJOR == 24 &&
#include "layernorm_attributes_generated.h"
#include "matmul_attributes_generated.h"
#include "pointwise_attributes_generated.h"
#include "reduction_attributes_generated.h"
#include "rmsnorm_attributes_generated.h"
#include "sdpa_attributes_generated.h"
#include "sdpa_backward_attributes_generated.h"
Expand Down Expand Up @@ -66,11 +67,12 @@ enum class NodeAttributes : uint8_t {
BlockScaleQuantizeAttributes = 14,
SdpaBackwardAttributes = 15,
CustomOpAttributes = 16,
ReductionAttributes = 17,
MIN = NONE,
MAX = CustomOpAttributes
MAX = ReductionAttributes
};

inline const NodeAttributes (&EnumValuesNodeAttributes())[17] {
inline const NodeAttributes (&EnumValuesNodeAttributes())[18] {
static const NodeAttributes values[] = {
NodeAttributes::NONE,
NodeAttributes::BatchnormInferenceAttributes,
Expand All @@ -88,13 +90,14 @@ inline const NodeAttributes (&EnumValuesNodeAttributes())[17] {
NodeAttributes::BlockScaleDequantizeAttributes,
NodeAttributes::BlockScaleQuantizeAttributes,
NodeAttributes::SdpaBackwardAttributes,
NodeAttributes::CustomOpAttributes
NodeAttributes::CustomOpAttributes,
NodeAttributes::ReductionAttributes
};
return values;
}

inline const char * const *EnumNamesNodeAttributes() {
static const char * const names[18] = {
static const char * const names[19] = {
"NONE",
"BatchnormInferenceAttributes",
"PointwiseAttributes",
Expand All @@ -112,13 +115,14 @@ inline const char * const *EnumNamesNodeAttributes() {
"BlockScaleQuantizeAttributes",
"SdpaBackwardAttributes",
"CustomOpAttributes",
"ReductionAttributes",
nullptr
};
return names;
}

inline const char *EnumNameNodeAttributes(NodeAttributes e) {
if (::flatbuffers::IsOutRange(e, NodeAttributes::NONE, NodeAttributes::CustomOpAttributes)) return "";
if (::flatbuffers::IsOutRange(e, NodeAttributes::NONE, NodeAttributes::ReductionAttributes)) return "";
const size_t index = static_cast<size_t>(e);
return EnumNamesNodeAttributes()[index];
}
Expand Down Expand Up @@ -191,6 +195,10 @@ template<> struct NodeAttributesTraits<hipdnn_data_sdk::data_objects::CustomOpAt
static const NodeAttributes enum_value = NodeAttributes::CustomOpAttributes;
};

template<> struct NodeAttributesTraits<hipdnn_data_sdk::data_objects::ReductionAttributes> {
static const NodeAttributes enum_value = NodeAttributes::ReductionAttributes;
};

template<typename T> struct NodeAttributesUnionTraits {
static const NodeAttributes enum_value = NodeAttributes::NONE;
};
Expand Down Expand Up @@ -259,6 +267,10 @@ template<> struct NodeAttributesUnionTraits<hipdnn_data_sdk::data_objects::Custo
static const NodeAttributes enum_value = NodeAttributes::CustomOpAttributes;
};

template<> struct NodeAttributesUnionTraits<hipdnn_data_sdk::data_objects::ReductionAttributesT> {
static const NodeAttributes enum_value = NodeAttributes::ReductionAttributes;
};

struct NodeAttributesUnion {
NodeAttributes type;
void *value;
Expand Down Expand Up @@ -417,6 +429,14 @@ struct NodeAttributesUnion {
return type == NodeAttributes::CustomOpAttributes ?
reinterpret_cast<const hipdnn_data_sdk::data_objects::CustomOpAttributesT *>(value) : nullptr;
}
hipdnn_data_sdk::data_objects::ReductionAttributesT *AsReductionAttributes() {
return type == NodeAttributes::ReductionAttributes ?
reinterpret_cast<hipdnn_data_sdk::data_objects::ReductionAttributesT *>(value) : nullptr;
}
const hipdnn_data_sdk::data_objects::ReductionAttributesT *AsReductionAttributes() const {
return type == NodeAttributes::ReductionAttributes ?
reinterpret_cast<const hipdnn_data_sdk::data_objects::ReductionAttributesT *>(value) : nullptr;
}
};


Expand Down Expand Up @@ -490,6 +510,10 @@ inline bool operator==(const NodeAttributesUnion &lhs, const NodeAttributesUnion
return *(reinterpret_cast<const hipdnn_data_sdk::data_objects::CustomOpAttributesT *>(lhs.value)) ==
*(reinterpret_cast<const hipdnn_data_sdk::data_objects::CustomOpAttributesT *>(rhs.value));
}
case NodeAttributes::ReductionAttributes: {
return *(reinterpret_cast<const hipdnn_data_sdk::data_objects::ReductionAttributesT *>(lhs.value)) ==
*(reinterpret_cast<const hipdnn_data_sdk::data_objects::ReductionAttributesT *>(rhs.value));
}
default: {
return false;
}
Expand Down Expand Up @@ -586,6 +610,9 @@ struct Node FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table {
const hipdnn_data_sdk::data_objects::CustomOpAttributes *attributes_as_CustomOpAttributes() const {
return attributes_type() == hipdnn_data_sdk::data_objects::NodeAttributes::CustomOpAttributes ? static_cast<const hipdnn_data_sdk::data_objects::CustomOpAttributes *>(attributes()) : nullptr;
}
const hipdnn_data_sdk::data_objects::ReductionAttributes *attributes_as_ReductionAttributes() const {
return attributes_type() == hipdnn_data_sdk::data_objects::NodeAttributes::ReductionAttributes ? static_cast<const hipdnn_data_sdk::data_objects::ReductionAttributes *>(attributes()) : nullptr;
}
void *mutable_attributes() {
return GetPointer<void *>(VT_ATTRIBUTES);
}
Expand Down Expand Up @@ -668,6 +695,10 @@ template<> inline const hipdnn_data_sdk::data_objects::CustomOpAttributes *Node:
return attributes_as_CustomOpAttributes();
}

template<> inline const hipdnn_data_sdk::data_objects::ReductionAttributes *Node::attributes_as<hipdnn_data_sdk::data_objects::ReductionAttributes>() const {
return attributes_as_ReductionAttributes();
}

struct NodeBuilder {
typedef Node Table;
::flatbuffers::FlatBufferBuilder &fbb_;
Expand Down Expand Up @@ -1098,6 +1129,10 @@ inline bool VerifyNodeAttributes(::flatbuffers::Verifier &verifier, const void *
auto ptr = reinterpret_cast<const hipdnn_data_sdk::data_objects::CustomOpAttributes *>(obj);
return verifier.VerifyTable(ptr);
}
case NodeAttributes::ReductionAttributes: {
auto ptr = reinterpret_cast<const hipdnn_data_sdk::data_objects::ReductionAttributes *>(obj);
return verifier.VerifyTable(ptr);
}
default: return true;
}
}
Expand Down Expand Up @@ -1181,6 +1216,10 @@ inline void *NodeAttributesUnion::UnPack(const void *obj, NodeAttributes type, c
auto ptr = reinterpret_cast<const hipdnn_data_sdk::data_objects::CustomOpAttributes *>(obj);
return ptr->UnPack(resolver);
}
case NodeAttributes::ReductionAttributes: {
auto ptr = reinterpret_cast<const hipdnn_data_sdk::data_objects::ReductionAttributes *>(obj);
return ptr->UnPack(resolver);
}
default: return nullptr;
}
}
Expand Down Expand Up @@ -1252,6 +1291,10 @@ inline ::flatbuffers::Offset<void> NodeAttributesUnion::Pack(::flatbuffers::Flat
auto ptr = reinterpret_cast<const hipdnn_data_sdk::data_objects::CustomOpAttributesT *>(value);
return CreateCustomOpAttributes(_fbb, ptr, _rehasher).Union();
}
case NodeAttributes::ReductionAttributes: {
auto ptr = reinterpret_cast<const hipdnn_data_sdk::data_objects::ReductionAttributesT *>(value);
return CreateReductionAttributes(_fbb, ptr, _rehasher).Union();
}
default: return 0;
}
}
Expand Down Expand Up @@ -1322,6 +1365,10 @@ inline NodeAttributesUnion::NodeAttributesUnion(const NodeAttributesUnion &u) :
value = new hipdnn_data_sdk::data_objects::CustomOpAttributesT(*reinterpret_cast<hipdnn_data_sdk::data_objects::CustomOpAttributesT *>(u.value));
break;
}
case NodeAttributes::ReductionAttributes: {
value = new hipdnn_data_sdk::data_objects::ReductionAttributesT(*reinterpret_cast<hipdnn_data_sdk::data_objects::ReductionAttributesT *>(u.value));
break;
}
default:
break;
}
Expand Down Expand Up @@ -1409,6 +1456,11 @@ inline void NodeAttributesUnion::Reset() {
delete ptr;
break;
}
case NodeAttributes::ReductionAttributes: {
auto ptr = reinterpret_cast<hipdnn_data_sdk::data_objects::ReductionAttributesT *>(value);
delete ptr;
break;
}
default: break;
}
value = nullptr;
Expand Down
Loading
Loading