Skip to content

Commit 86247be

Browse files
SamuelReeder, claude, bghimireamd
authored
[hipDNN] Add RMSNorm support to hipDNN frontend (#4823)
## Summary

- Adds RMSNorm (Root Mean Square Normalization) frontend API to hipDNN
- FlatBuffer schema (`rmsnorm_attributes.fbs`) with x, scale, epsilon, y (required) and inv_rms, bias (optional)
- Frontend `RmsnormAttributes` class with serialization/deserialization
- Frontend `RmsnormNode` with validation and property inference
- JSON serialization support
- 25 unit tests across frontend, data SDK, and test SDK (all passing)

## Details

RMSNorm normalizes by root mean square without mean subtraction: `y = x / sqrt(mean(x^2) + epsilon) * scale`.

This PR adds the frontend graph support — the reference implementation will follow in a stacked PR (ALMIOPEN-1098).

---

Generated by `/orchestrate`

[ALMIOPEN-1096]: https://amd-hub.atlassian.net/browse/ALMIOPEN-1096?atlOrigin=eyJpIjoiNWRkNTljNzYxNjVmNDY3MDlhMDU5Y2ZhYzA5YTRkZjUiLCJwIjoiZ2l0aHViLWNvbS1KU1cifQ

---------

Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
Co-authored-by: Bibek Ghimire <bghimire@amd.com>
1 parent 240a74b commit 86247be

File tree

24 files changed

+2988
-11
lines changed

24 files changed

+2988
-11
lines changed

projects/hipdnn/data_sdk/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -158,6 +158,7 @@ if(HIPDNN_GENERATE_SDK_HEADERS)
158158
schemas/matmul_attributes.fbs
159159
schemas/norm_common.fbs
160160
schemas/pointwise_attributes.fbs
161+
schemas/rmsnorm_attributes.fbs
161162
schemas/tensor_attributes.fbs
162163
SCHEMAS_DIR "${CMAKE_CURRENT_SOURCE_DIR}/schemas"
163164
PRIMARY_VERSION ${HIPDNN_FLATBUFFERS_VERSION}
Lines changed: 56 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,56 @@
1+
// automatically generated by the FlatBuffers compiler, do not modify
2+
3+
4+
#ifndef FLATBUFFERS_GENERATED_NORMCOMMON_HIPDNN_DATA_SDK_DATA_OBJECTS_H_
5+
#define FLATBUFFERS_GENERATED_NORMCOMMON_HIPDNN_DATA_SDK_DATA_OBJECTS_H_
6+
7+
#include "flatbuffers/flatbuffers.h"
8+
9+
// Ensure the included flatbuffers.h is the same version as when this file was
10+
// generated, otherwise it may not be compatible.
11+
static_assert(FLATBUFFERS_VERSION_MAJOR == 25 &&
12+
FLATBUFFERS_VERSION_MINOR == 9 &&
13+
FLATBUFFERS_VERSION_REVISION == 23,
14+
"Non-compatible flatbuffers version included");
15+
16+
17+
namespace hipdnn_data_sdk {
18+
namespace data_objects {
19+
20+
/// Forward-pass phase of a normalization operation.
///
/// The underlying type is a signed 8-bit integer. `MIN` and `MAX` are aliases
/// for the first and last real enumerators (`NOT_SET` and `TRAINING`), so they
/// do not introduce additional distinct values.
enum class NormFwdPhase : int8_t {
  NOT_SET = 0,
  INFERENCE = 1,
  TRAINING = 2,
  MIN = NOT_SET,
  MAX = TRAINING
};
27+
28+
/// Returns a reference to a function-local static array containing the three
/// distinct NormFwdPhase enumerators (NOT_SET, INFERENCE, TRAINING) in that
/// order. The MIN/MAX aliases are not listed separately.
inline const NormFwdPhase (&EnumValuesNormFwdPhase())[3] {
  static const NormFwdPhase values[] = {
    NormFwdPhase::NOT_SET,
    NormFwdPhase::INFERENCE,
    NormFwdPhase::TRAINING
  };
  return values;
}
36+
37+
/// Returns a pointer to a function-local static table of enumerator names.
///
/// The table has four entries: "NOT_SET", "INFERENCE", "TRAINING", followed by
/// a terminating nullptr. Entry i holds the name of the enumerator whose
/// numeric value is i.
inline const char * const *EnumNamesNormFwdPhase() {
  static const char * const names[4] = {
    "NOT_SET",
    "INFERENCE",
    "TRAINING",
    nullptr
  };
  return names;
}
46+
47+
/// Maps a NormFwdPhase value to its name.
///
/// @param e Value to look up.
/// @return An empty string when ::flatbuffers::IsOutRange reports `e` as
///         outside [NOT_SET, TRAINING]; otherwise the entry of
///         EnumNamesNormFwdPhase() indexed by the numeric value of `e`.
inline const char *EnumNameNormFwdPhase(NormFwdPhase e) {
  // Range-check first so the table lookup below never indexes past the names.
  if (::flatbuffers::IsOutRange(e, NormFwdPhase::NOT_SET, NormFwdPhase::TRAINING)) return "";
  const size_t index = static_cast<size_t>(e);
  return EnumNamesNormFwdPhase()[index];
}
52+
53+
} // namespace data_objects
54+
} // namespace hipdnn_data_sdk
55+
56+
#endif // FLATBUFFERS_GENERATED_NORMCOMMON_HIPDNN_DATA_SDK_DATA_OBJECTS_H_
Lines changed: 231 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,231 @@
1+
// automatically generated by the FlatBuffers compiler, do not modify
2+
3+
4+
#ifndef FLATBUFFERS_GENERATED_RMSNORMATTRIBUTES_HIPDNN_DATA_SDK_DATA_OBJECTS_H_
5+
#define FLATBUFFERS_GENERATED_RMSNORMATTRIBUTES_HIPDNN_DATA_SDK_DATA_OBJECTS_H_
6+
7+
#include "flatbuffers/flatbuffers.h"
8+
9+
// Ensure the included flatbuffers.h is the same version as when this file was
10+
// generated, otherwise it may not be compatible.
11+
static_assert(FLATBUFFERS_VERSION_MAJOR == 25 &&
12+
FLATBUFFERS_VERSION_MINOR == 9 &&
13+
FLATBUFFERS_VERSION_REVISION == 23,
14+
"Non-compatible flatbuffers version included");
15+
16+
#include "norm_common_generated.h"
17+
18+
namespace hipdnn_data_sdk {
19+
namespace data_objects {
20+
21+
// Forward declarations of the RMSNorm table accessor, its builder, and the
// native (object-API) struct, so the definitions below can refer to each other.
struct RMSNormAttributes;
struct RMSNormAttributesBuilder;
struct RMSNormAttributesT;

// Comparison operators for the native struct; declared here and defined
// inline further down in this header.
bool operator==(const RMSNormAttributesT &lhs, const RMSNormAttributesT &rhs);
bool operator!=(const RMSNormAttributesT &lhs, const RMSNormAttributesT &rhs);
27+
28+
struct RMSNormAttributesT : public ::flatbuffers::NativeTable {
29+
typedef RMSNormAttributes TableType;
30+
int64_t x_tensor_uid = 0;
31+
int64_t scale_tensor_uid = 0;
32+
int64_t epsilon_tensor_uid = 0;
33+
int64_t y_tensor_uid = 0;
34+
::flatbuffers::Optional<int64_t> bias_tensor_uid = ::flatbuffers::nullopt;
35+
::flatbuffers::Optional<int64_t> inv_rms_tensor_uid = ::flatbuffers::nullopt;
36+
hipdnn_data_sdk::data_objects::NormFwdPhase forward_phase = hipdnn_data_sdk::data_objects::NormFwdPhase::NOT_SET;
37+
};
38+
39+
/// Accessor for an RMSNormAttributes table stored in a FlatBuffer.
///
/// Each field has a getter and a `mutate_*` setter. The four non-optional
/// tensor UIDs and `forward_phase` are read and written with 0 as the default
/// value. The bias and inv_rms tensor UIDs are read through GetOptional, and
/// their setters take no default argument.
struct RMSNormAttributes FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table {
  typedef RMSNormAttributesT NativeTableType;
  typedef RMSNormAttributesBuilder Builder;
  // VTable offsets of the fields, in declaration order.
  enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE {
    VT_X_TENSOR_UID = 4,
    VT_SCALE_TENSOR_UID = 6,
    VT_EPSILON_TENSOR_UID = 8,
    VT_Y_TENSOR_UID = 10,
    VT_BIAS_TENSOR_UID = 12,
    VT_INV_RMS_TENSOR_UID = 14,
    VT_FORWARD_PHASE = 16
  };
  int64_t x_tensor_uid() const {
    return GetField<int64_t>(VT_X_TENSOR_UID, 0);
  }
  bool mutate_x_tensor_uid(int64_t _x_tensor_uid = 0) {
    return SetField<int64_t>(VT_X_TENSOR_UID, _x_tensor_uid, 0);
  }
  int64_t scale_tensor_uid() const {
    return GetField<int64_t>(VT_SCALE_TENSOR_UID, 0);
  }
  bool mutate_scale_tensor_uid(int64_t _scale_tensor_uid = 0) {
    return SetField<int64_t>(VT_SCALE_TENSOR_UID, _scale_tensor_uid, 0);
  }
  int64_t epsilon_tensor_uid() const {
    return GetField<int64_t>(VT_EPSILON_TENSOR_UID, 0);
  }
  bool mutate_epsilon_tensor_uid(int64_t _epsilon_tensor_uid = 0) {
    return SetField<int64_t>(VT_EPSILON_TENSOR_UID, _epsilon_tensor_uid, 0);
  }
  int64_t y_tensor_uid() const {
    return GetField<int64_t>(VT_Y_TENSOR_UID, 0);
  }
  bool mutate_y_tensor_uid(int64_t _y_tensor_uid = 0) {
    return SetField<int64_t>(VT_Y_TENSOR_UID, _y_tensor_uid, 0);
  }
  // Optional field: the getter returns an Optional rather than a defaulted value.
  ::flatbuffers::Optional<int64_t> bias_tensor_uid() const {
    return GetOptional<int64_t, int64_t>(VT_BIAS_TENSOR_UID);
  }
  bool mutate_bias_tensor_uid(int64_t _bias_tensor_uid) {
    return SetField<int64_t>(VT_BIAS_TENSOR_UID, _bias_tensor_uid);
  }
  // Optional field: the getter returns an Optional rather than a defaulted value.
  ::flatbuffers::Optional<int64_t> inv_rms_tensor_uid() const {
    return GetOptional<int64_t, int64_t>(VT_INV_RMS_TENSOR_UID);
  }
  bool mutate_inv_rms_tensor_uid(int64_t _inv_rms_tensor_uid) {
    return SetField<int64_t>(VT_INV_RMS_TENSOR_UID, _inv_rms_tensor_uid);
  }
  // Stored as int8_t and converted to/from the enum at the accessor boundary.
  hipdnn_data_sdk::data_objects::NormFwdPhase forward_phase() const {
    return static_cast<hipdnn_data_sdk::data_objects::NormFwdPhase>(GetField<int8_t>(VT_FORWARD_PHASE, 0));
  }
  bool mutate_forward_phase(hipdnn_data_sdk::data_objects::NormFwdPhase _forward_phase = static_cast<hipdnn_data_sdk::data_objects::NormFwdPhase>(0)) {
    return SetField<int8_t>(VT_FORWARD_PHASE, static_cast<int8_t>(_forward_phase), 0);
  }
  /// Runs the verifier over the table start, every field (8-byte alignment for
  /// the int64 UIDs, 1-byte for forward_phase), and the table end.
  /// @return true only if every check passes.
  bool Verify(::flatbuffers::Verifier &verifier) const {
    return VerifyTableStart(verifier) &&
           VerifyField<int64_t>(verifier, VT_X_TENSOR_UID, 8) &&
           VerifyField<int64_t>(verifier, VT_SCALE_TENSOR_UID, 8) &&
           VerifyField<int64_t>(verifier, VT_EPSILON_TENSOR_UID, 8) &&
           VerifyField<int64_t>(verifier, VT_Y_TENSOR_UID, 8) &&
           VerifyField<int64_t>(verifier, VT_BIAS_TENSOR_UID, 8) &&
           VerifyField<int64_t>(verifier, VT_INV_RMS_TENSOR_UID, 8) &&
           VerifyField<int8_t>(verifier, VT_FORWARD_PHASE, 1) &&
           verifier.EndTable();
  }
  // Object-API conversions; defined out of line further down in this header.
  RMSNormAttributesT *UnPack(const ::flatbuffers::resolver_function_t *_resolver = nullptr) const;
  void UnPackTo(RMSNormAttributesT *_o, const ::flatbuffers::resolver_function_t *_resolver = nullptr) const;
  static ::flatbuffers::Offset<RMSNormAttributes> Pack(::flatbuffers::FlatBufferBuilder &_fbb, const RMSNormAttributesT* _o, const ::flatbuffers::rehasher_function_t *_rehasher = nullptr);
};
108+
109+
struct RMSNormAttributesBuilder {
110+
typedef RMSNormAttributes Table;
111+
::flatbuffers::FlatBufferBuilder &fbb_;
112+
::flatbuffers::uoffset_t start_;
113+
void add_x_tensor_uid(int64_t x_tensor_uid) {
114+
fbb_.AddElement<int64_t>(RMSNormAttributes::VT_X_TENSOR_UID, x_tensor_uid, 0);
115+
}
116+
void add_scale_tensor_uid(int64_t scale_tensor_uid) {
117+
fbb_.AddElement<int64_t>(RMSNormAttributes::VT_SCALE_TENSOR_UID, scale_tensor_uid, 0);
118+
}
119+
void add_epsilon_tensor_uid(int64_t epsilon_tensor_uid) {
120+
fbb_.AddElement<int64_t>(RMSNormAttributes::VT_EPSILON_TENSOR_UID, epsilon_tensor_uid, 0);
121+
}
122+
void add_y_tensor_uid(int64_t y_tensor_uid) {
123+
fbb_.AddElement<int64_t>(RMSNormAttributes::VT_Y_TENSOR_UID, y_tensor_uid, 0);
124+
}
125+
void add_bias_tensor_uid(int64_t bias_tensor_uid) {
126+
fbb_.AddElement<int64_t>(RMSNormAttributes::VT_BIAS_TENSOR_UID, bias_tensor_uid);
127+
}
128+
void add_inv_rms_tensor_uid(int64_t inv_rms_tensor_uid) {
129+
fbb_.AddElement<int64_t>(RMSNormAttributes::VT_INV_RMS_TENSOR_UID, inv_rms_tensor_uid);
130+
}
131+
void add_forward_phase(hipdnn_data_sdk::data_objects::NormFwdPhase forward_phase) {
132+
fbb_.AddElement<int8_t>(RMSNormAttributes::VT_FORWARD_PHASE, static_cast<int8_t>(forward_phase), 0);
133+
}
134+
explicit RMSNormAttributesBuilder(::flatbuffers::FlatBufferBuilder &_fbb)
135+
: fbb_(_fbb) {
136+
start_ = fbb_.StartTable();
137+
}
138+
::flatbuffers::Offset<RMSNormAttributes> Finish() {
139+
const auto end = fbb_.EndTable(start_);
140+
auto o = ::flatbuffers::Offset<RMSNormAttributes>(end);
141+
return o;
142+
}
143+
};
144+
145+
/// Builds a complete RMSNormAttributes table from individual field values.
///
/// @param _fbb               Builder that receives the table.
/// @param x_tensor_uid       Defaults to 0.
/// @param scale_tensor_uid   Defaults to 0.
/// @param epsilon_tensor_uid Defaults to 0.
/// @param y_tensor_uid       Defaults to 0.
/// @param bias_tensor_uid    Optional; added only when it holds a value.
/// @param inv_rms_tensor_uid Optional; added only when it holds a value.
/// @param forward_phase      Defaults to NormFwdPhase::NOT_SET.
/// @return Offset of the finished table.
inline ::flatbuffers::Offset<RMSNormAttributes> CreateRMSNormAttributes(
    ::flatbuffers::FlatBufferBuilder &_fbb,
    int64_t x_tensor_uid = 0,
    int64_t scale_tensor_uid = 0,
    int64_t epsilon_tensor_uid = 0,
    int64_t y_tensor_uid = 0,
    ::flatbuffers::Optional<int64_t> bias_tensor_uid = ::flatbuffers::nullopt,
    ::flatbuffers::Optional<int64_t> inv_rms_tensor_uid = ::flatbuffers::nullopt,
    hipdnn_data_sdk::data_objects::NormFwdPhase forward_phase = hipdnn_data_sdk::data_objects::NormFwdPhase::NOT_SET) {
  RMSNormAttributesBuilder builder_(_fbb);
  // The 64-bit fields are added first (in reverse declaration order) and the
  // 8-bit forward_phase last; keep this order as emitted by the generator.
  if(inv_rms_tensor_uid) { builder_.add_inv_rms_tensor_uid(*inv_rms_tensor_uid); }
  if(bias_tensor_uid) { builder_.add_bias_tensor_uid(*bias_tensor_uid); }
  builder_.add_y_tensor_uid(y_tensor_uid);
  builder_.add_epsilon_tensor_uid(epsilon_tensor_uid);
  builder_.add_scale_tensor_uid(scale_tensor_uid);
  builder_.add_x_tensor_uid(x_tensor_uid);
  builder_.add_forward_phase(forward_phase);
  return builder_.Finish();
}
164+
165+
// Object-API overload: builds the table from a native RMSNormAttributesT.
// Declared here so Pack() can call it; defined inline later in this header.
::flatbuffers::Offset<RMSNormAttributes> CreateRMSNormAttributes(::flatbuffers::FlatBufferBuilder &_fbb, const RMSNormAttributesT *_o, const ::flatbuffers::rehasher_function_t *_rehasher = nullptr);
166+
167+
168+
/// Member-wise equality for the native struct: true only when all seven
/// fields (the four tensor UIDs, the two optional UIDs, and forward_phase)
/// compare equal.
inline bool operator==(const RMSNormAttributesT &lhs, const RMSNormAttributesT &rhs) {
  return
      (lhs.x_tensor_uid == rhs.x_tensor_uid) &&
      (lhs.scale_tensor_uid == rhs.scale_tensor_uid) &&
      (lhs.epsilon_tensor_uid == rhs.epsilon_tensor_uid) &&
      (lhs.y_tensor_uid == rhs.y_tensor_uid) &&
      (lhs.bias_tensor_uid == rhs.bias_tensor_uid) &&
      (lhs.inv_rms_tensor_uid == rhs.inv_rms_tensor_uid) &&
      (lhs.forward_phase == rhs.forward_phase);
}
178+
179+
/// Inequality for the native struct, defined as the negation of operator==.
inline bool operator!=(const RMSNormAttributesT &lhs, const RMSNormAttributesT &rhs) {
    return !(lhs == rhs);
}
182+
183+
184+
/// Allocates a new native RMSNormAttributesT, fills it via UnPackTo(), and
/// returns it as a raw pointer.
///
/// Ownership passes to the caller, who is responsible for deleting the result.
/// The object is held in a unique_ptr while UnPackTo() runs and released only
/// afterwards.
inline RMSNormAttributesT *RMSNormAttributes::UnPack(const ::flatbuffers::resolver_function_t *_resolver) const {
  auto _o = std::unique_ptr<RMSNormAttributesT>(new RMSNormAttributesT());
  UnPackTo(_o.get(), _resolver);
  return _o.release();
}
189+
190+
/// Copies every field of this table into an existing native struct.
///
/// @param _o        Destination; dereferenced unconditionally, so it must not
///                  be null.
/// @param _resolver Not used by this function.
inline void RMSNormAttributes::UnPackTo(RMSNormAttributesT *_o, const ::flatbuffers::resolver_function_t *_resolver) const {
  (void)_o;
  (void)_resolver;
  { auto _e = x_tensor_uid(); _o->x_tensor_uid = _e; }
  { auto _e = scale_tensor_uid(); _o->scale_tensor_uid = _e; }
  { auto _e = epsilon_tensor_uid(); _o->epsilon_tensor_uid = _e; }
  { auto _e = y_tensor_uid(); _o->y_tensor_uid = _e; }
  { auto _e = bias_tensor_uid(); _o->bias_tensor_uid = _e; }
  { auto _e = inv_rms_tensor_uid(); _o->inv_rms_tensor_uid = _e; }
  { auto _e = forward_phase(); _o->forward_phase = _e; }
}
201+
202+
/// Serializes a native RMSNormAttributesT into `_fbb` by forwarding to the
/// object-API overload of CreateRMSNormAttributes, and returns the resulting
/// table offset.
inline ::flatbuffers::Offset<RMSNormAttributes> RMSNormAttributes::Pack(::flatbuffers::FlatBufferBuilder &_fbb, const RMSNormAttributesT* _o, const ::flatbuffers::rehasher_function_t *_rehasher) {
  return CreateRMSNormAttributes(_fbb, _o, _rehasher);
}
205+
206+
/// Object-API overload: builds an RMSNormAttributes table from a native struct.
///
/// Each member of `*_o` is copied into a local and passed to the field-wise
/// CreateRMSNormAttributes overload.
///
/// @param _fbb      Builder that receives the table.
/// @param _o        Source struct; dereferenced unconditionally, so it must
///                  not be null.
/// @param _rehasher Only stored in the unused `_va` helper below.
/// @return Offset of the finished table.
inline ::flatbuffers::Offset<RMSNormAttributes> CreateRMSNormAttributes(::flatbuffers::FlatBufferBuilder &_fbb, const RMSNormAttributesT *_o, const ::flatbuffers::rehasher_function_t *_rehasher) {
  (void)_rehasher;
  (void)_o;
  // `_va` is never read in this function; it is explicitly voided to silence
  // unused-variable warnings and kept to match the generator's output.
  struct _VectorArgs { ::flatbuffers::FlatBufferBuilder *__fbb; const RMSNormAttributesT* __o; const ::flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va;
  auto _x_tensor_uid = _o->x_tensor_uid;
  auto _scale_tensor_uid = _o->scale_tensor_uid;
  auto _epsilon_tensor_uid = _o->epsilon_tensor_uid;
  auto _y_tensor_uid = _o->y_tensor_uid;
  auto _bias_tensor_uid = _o->bias_tensor_uid;
  auto _inv_rms_tensor_uid = _o->inv_rms_tensor_uid;
  auto _forward_phase = _o->forward_phase;
  return hipdnn_data_sdk::data_objects::CreateRMSNormAttributes(
      _fbb,
      _x_tensor_uid,
      _scale_tensor_uid,
      _epsilon_tensor_uid,
      _y_tensor_uid,
      _bias_tensor_uid,
      _inv_rms_tensor_uid,
      _forward_phase);
}
227+
228+
} // namespace data_objects
229+
} // namespace hipdnn_data_sdk
230+
231+
#endif // FLATBUFFERS_GENERATED_RMSNORMATTRIBUTES_HIPDNN_DATA_SDK_DATA_OBJECTS_H_

0 commit comments

Comments
 (0)