@@ -72,6 +72,7 @@ NnOpQuantType getOpQuantType(NnFloatType input, NnFloatType weight, NnFloatType
7272
7373const char *opCodeToString (NnOpCode code) {
7474 if (code == OP_MERGE_ADD) return " MERGE_ADD" ;
75+ if (code == OP_MERGE_SUM) return " MERGE_SUM" ;
7576 if (code == OP_EMBEDDING) return " EMBEDDING" ;
7677 if (code == OP_INV_RMS) return " INV_RMS" ;
7778 if (code == OP_RMS_NORM) return " RMS_NORM" ;
@@ -81,7 +82,11 @@ const char *opCodeToString(NnOpCode code) {
8182 if (code == OP_GELU) return " GELU" ;
8283 if (code == OP_SILU) return " SILU" ;
8384 if (code == OP_MUL) return " MUL" ;
85+ if (code == OP_SCALE) return " SCALE" ;
8486 if (code == OP_CAST) return " CAST" ;
87+ if (code == OP_REPEAT_Z) return " REPEAT_Z" ;
88+ if (code == OP_SHIFT) return " SHIFT" ;
89+ if (code == OP_MOE_GATE) return " MOE_GATE" ;
8590 throw std::invalid_argument (" Unknown op code" );
8691}
8792
@@ -97,17 +102,22 @@ const char *opQuantTypeToString(NnOpQuantType type) {
97102 throw std::invalid_argument (" Unknown op quant type" );
98103}
99104
100- NnSize2D size0 () {
101- return { F_UNK, 0 , 0 , 0 , 0 };
105+ NnSize3D size0 () {
106+ return { F_UNK, 0 , 0 , 0 , 0 , 0 };
102107}
103108
104- NnSize2D size1D (NnFloatType floatType, NnUint x) {
105- return size2D (floatType, 1 , x);
109+ NnSize3D size1D (NnFloatType floatType, NnUint x) {
110+ return size3D (floatType, 1 , 1 , x);
106111}
107112
108- NnSize2D size2D (NnFloatType floatType, NnUint y, NnUint x) {
109- NnSize length = y * x;
110- return { floatType, y, x, length, getBytes (floatType, length) };
113+ NnSize3D size2D (NnFloatType floatType, NnUint y, NnUint x) {
114+ return size3D (floatType, 1 , y, x);
115+ }
116+
117+ NnSize3D size3D (NnFloatType floatType, NnUint z, NnUint y, NnUint x) {
118+ NnSize len = z * y * x;
119+ NnSize lenXY = y * x;
120+ return { floatType, z, y, x, len, getBytes (floatType, len), getBytes (floatType, lenXY) };
111121}
112122
113123NnPointerConfig pointerBatchConfig (NnPointerSource source, NnUint index) {
0 commit comments