@@ -34,30 +34,38 @@ channelwise_8bit_activation_groupwise_lowbit_weight_1x1x32_f32_neondot(
3434 has_clamp);
3535
3636 std::vector<char > activation_data (
37- activation_data_size<has_weight_zeros> (m, k, group_size));
38- prepare_activation_data<has_weight_zeros> (
37+ activation_data_size (m, k, group_size, has_weight_zeros ));
38+ prepare_activation_data (
3939 (void *)activation_data.data (),
4040 m,
4141 k,
4242 group_size,
43- test_case.activations .data ());
43+ test_case.activations .data (),
44+ has_weight_zeros);
4445
45- std::vector<char > weight_data (
46- weight_data_size<weight_nbit, has_weight_zeros, has_bias>(
47- n, k, group_size));
48- prepare_weight_data<weight_nbit, has_weight_zeros, has_bias>(
46+ std::vector<char > weight_data (weight_data_size<weight_nbit>(
47+ n, k, group_size, has_weight_zeros, has_bias));
48+ int8_t * weight_zeros_ptr = nullptr ;
49+ if (has_weight_zeros) {
50+ weight_zeros_ptr = test_case.weight_zeros .data ();
51+ }
52+ float * bias_ptr = nullptr ;
53+ if (has_bias) {
54+ bias_ptr = test_case.bias .data ();
55+ }
56+ prepare_weight_data<weight_nbit>(
4957 (void *)weight_data.data (),
5058 n,
5159 k,
5260 group_size,
5361 test_case.weight_qvals .data (),
5462 test_case.weight_scales .data (),
55- test_case. weight_zeros . data () ,
56- test_case. bias . data () );
63+ weight_zeros_ptr ,
64+ bias_ptr );
5765
5866 std::vector<float > output (m * k);
5967 for (auto _ : state) {
60- kernel<weight_nbit, has_weight_zeros, has_bias, has_clamp >(
68+ kernel<weight_nbit>(
6169 output.data (),
6270 /* output_m_stride=*/ n,
6371 m,
@@ -67,7 +75,10 @@ channelwise_8bit_activation_groupwise_lowbit_weight_1x1x32_f32_neondot(
6775 weight_data.data (),
6876 activation_data.data (),
6977 test_case.clamp_min ,
70- test_case.clamp_max );
78+ test_case.clamp_max ,
79+ has_weight_zeros,
80+ has_bias,
81+ has_clamp);
7182 }
7283}
7384
@@ -95,30 +106,38 @@ channelwise_8bit_activation_groupwise_lowbit_weight_1x4x16_f32_neondot(
95106 has_clamp);
96107
97108 std::vector<char > activation_data (
98- activation_data_size<has_weight_zeros> (m, k, group_size));
99- prepare_activation_data<has_weight_zeros> (
109+ activation_data_size (m, k, group_size, has_weight_zeros ));
110+ prepare_activation_data (
100111 (void *)activation_data.data (),
101112 m,
102113 k,
103114 group_size,
104- test_case.activations .data ());
115+ test_case.activations .data (),
116+ has_weight_zeros);
105117
106- std::vector<char > weight_data (
107- weight_data_size<weight_nbit, has_weight_zeros, has_bias>(
108- n, k, group_size));
109- prepare_weight_data<weight_nbit, has_weight_zeros, has_bias>(
118+ std::vector<char > weight_data (weight_data_size<weight_nbit>(
119+ n, k, group_size, has_weight_zeros, has_bias));
120+ int8_t * weight_zeros_ptr = nullptr ;
121+ if (has_weight_zeros) {
122+ weight_zeros_ptr = test_case.weight_zeros .data ();
123+ }
124+ float * bias_ptr = nullptr ;
125+ if (has_bias) {
126+ bias_ptr = test_case.bias .data ();
127+ }
128+ prepare_weight_data<weight_nbit>(
110129 (void *)weight_data.data (),
111130 n,
112131 k,
113132 group_size,
114133 test_case.weight_qvals .data (),
115134 test_case.weight_scales .data (),
116- test_case. weight_zeros . data () ,
117- test_case. bias . data () );
135+ weight_zeros_ptr ,
136+ bias_ptr );
118137
119138 std::vector<float > output (m * k);
120139 for (auto _ : state) {
121- kernel<weight_nbit, has_weight_zeros, has_bias, has_clamp >(
140+ kernel<weight_nbit>(
122141 output.data (),
123142 /* output_m_stride=*/ n,
124143 m,
@@ -128,7 +147,10 @@ channelwise_8bit_activation_groupwise_lowbit_weight_1x4x16_f32_neondot(
128147 weight_data.data (),
129148 activation_data.data (),
130149 test_case.clamp_min ,
131- test_case.clamp_max );
150+ test_case.clamp_max ,
151+ has_weight_zeros,
152+ has_bias,
153+ has_clamp);
132154 }
133155}
134156
@@ -156,30 +178,38 @@ channelwise_8bit_activation_groupwise_lowbit_weight_1x8x16_f32_neondot(
156178 has_clamp);
157179
158180 std::vector<char > activation_data (
159- activation_data_size<has_weight_zeros> (m, k, group_size));
160- prepare_activation_data<has_weight_zeros> (
181+ activation_data_size (m, k, group_size, has_weight_zeros ));
182+ prepare_activation_data (
161183 (void *)activation_data.data (),
162184 m,
163185 k,
164186 group_size,
165- test_case.activations .data ());
187+ test_case.activations .data (),
188+ has_weight_zeros);
166189
167- std::vector<char > weight_data (
168- weight_data_size<weight_nbit, has_weight_zeros, has_bias>(
169- n, k, group_size));
170- prepare_weight_data<weight_nbit, has_weight_zeros, has_bias>(
190+ std::vector<char > weight_data (weight_data_size<weight_nbit>(
191+ n, k, group_size, has_weight_zeros, has_bias));
192+ int8_t * weight_zeros_ptr = nullptr ;
193+ if (has_weight_zeros) {
194+ weight_zeros_ptr = test_case.weight_zeros .data ();
195+ }
196+ float * bias_ptr = nullptr ;
197+ if (has_bias) {
198+ bias_ptr = test_case.bias .data ();
199+ }
200+ prepare_weight_data<weight_nbit>(
171201 (void *)weight_data.data (),
172202 n,
173203 k,
174204 group_size,
175205 test_case.weight_qvals .data (),
176206 test_case.weight_scales .data (),
177- test_case. weight_zeros . data () ,
178- test_case. bias . data () );
207+ weight_zeros_ptr ,
208+ bias_ptr );
179209
180210 std::vector<float > output (m * k);
181211 for (auto _ : state) {
182- kernel<weight_nbit, has_weight_zeros, has_bias, has_clamp >(
212+ kernel<weight_nbit>(
183213 output.data (),
184214 /* output_m_stride=*/ n,
185215 m,
@@ -189,7 +219,10 @@ channelwise_8bit_activation_groupwise_lowbit_weight_1x8x16_f32_neondot(
189219 weight_data.data (),
190220 activation_data.data (),
191221 test_case.clamp_min ,
192- test_case.clamp_max );
222+ test_case.clamp_max ,
223+ has_weight_zeros,
224+ has_bias,
225+ has_clamp);
193226 }
194227}
195228
0 commit comments