@@ -181,6 +181,62 @@ inline void groupwise_lowbit_weight_lut_kernel_1x4x32(
181
181
has_bias,
182
182
has_clamp);
183
183
}
184
+
185
/**
 * @brief Returns the byte offset of row `m_idx` within the packed
 * activation buffer.
 *
 * @param m_idx Row index whose offset is requested.
 * @param k The K dimension (row width) of the activation matrix.
 * @return The byte offset from the start of the buffer.
 */
inline size_t packed_activations_offset(int m_idx, int k) {
  // Activations are stored in a simple padded row-major float layout,
  // so each row occupies exactly `k * sizeof(float)` bytes.
  const size_t row_stride_bytes = sizeof(float) * static_cast<size_t>(k);
  return static_cast<size_t>(m_idx) * row_stride_bytes;
}
197
+
198
+ /* *
199
+ * @brief Calculates the byte offset for a given column index in the packed
200
+ * weights buffer. The buffer is assumed to be laid out as a series of
201
+ * contiguous blocks, where each block contains `nr` packed columns.
202
+ *
203
+ * @param n_idx The starting column index of the tile. Must be a multiple of
204
+ * `nr`.
205
+ * @param k The inner dimension of the matrix.
206
+ * @param weight_nbit The number of bits for the quantized weights.
207
+ * @param has_scales Whether weight scales are present.
208
+ * @param has_bias Whether a bias vector is packed.
209
+ * @param nr The micro-kernel tiling parameter for the N dimension.
210
+ * @param kr The micro-kernel tiling parameter for the K dimension.
211
+ * @return The byte offset into the packed weights buffer.
212
+ */
213
+ inline size_t packed_weights_offset (
214
+ int n_idx,
215
+ int k,
216
+ int weight_nbit,
217
+ int scale_group_size,
218
+ bool has_scales,
219
+ bool has_bias,
220
+ int nr,
221
+ int kr,
222
+ int sr) {
223
+ (void )sr; // unused
224
+ assert (n_idx % nr == 0 );
225
+
226
+ const size_t packed_tile_size_for_nr_cols = packed_weights_size (
227
+ /* n=*/ nr, // The size we are calculating is for a single tile of width
228
+ // `nr`.
229
+ k,
230
+ weight_nbit,
231
+ scale_group_size,
232
+ has_scales,
233
+ has_bias,
234
+ nr,
235
+ kr,
236
+ sr);
237
+
238
+ return (n_idx / nr) * packed_tile_size_for_nr_cols;
239
+ }
184
240
} // namespace
185
241
// torchao::kernels::cpu::aarch64::linear::groupwise_lowbit_weight_lut
186
242
0 commit comments