Skip to content

Commit d9f8a68

Browse files
authored
Add offset function for activation and weight for multithread. (#2514) (#2521)
Summary: Pull Request resolved: #2514 Reviewed By: metascroy Differential Revision: D77619995
1 parent f24f37b commit d9f8a68

File tree

1 file changed

+56
-0
lines changed

1 file changed

+56
-0
lines changed

torchao/experimental/kernels/cpu/aarch64/linear/groupwise_lowbit_weight/groupwise_lowbit_weight_lut.h

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -181,6 +181,62 @@ inline void groupwise_lowbit_weight_lut_kernel_1x4x32(
181181
has_bias,
182182
has_clamp);
183183
}
184+
185+
/**
 * @brief Calculates the byte offset for a specific row in the packed activation
 * buffer.
 *
 * The packed activation buffer is laid out as a padded row-major array of
 * float, so the offset of row `m_idx` is simply `m_idx * k` float elements.
 *
 * @param m_idx The row index for which to calculate the offset. Must be >= 0.
 * @param k The K dimension (width) of the activation matrix. Must be >= 0.
 * @return The byte offset from the start of the buffer.
 */
inline size_t packed_activations_offset(int m_idx, int k) {
  // Negative arguments would wrap to an enormous size_t offset; fail fast in
  // debug builds (mirrors the precondition assert in packed_weights_offset).
  assert(m_idx >= 0);
  assert(k >= 0);
  // For a simple padded row-major format, the offset is just m_idx * k.
  // sizeof(float) is size_t, so the arithmetic is carried out in size_t
  // left-to-right and cannot overflow int.
  return sizeof(float) * m_idx * k;
}
197+
198+
/**
199+
* @brief Calculates the byte offset for a given column index in the packed
200+
* weights buffer. The buffer is assumed to be laid out as a series of
201+
* contiguous blocks, where each block contains `nr` packed columns.
202+
*
203+
* @param n_idx The starting column index of the tile. Must be a multiple of
204+
* `nr`.
205+
* @param k The inner dimension of the matrix.
206+
* @param weight_nbit The number of bits for the quantized weights.
207+
* @param has_scales Whether weight scales are present.
208+
* @param has_bias Whether a bias vector is packed.
209+
* @param nr The micro-kernel tiling parameter for the N dimension.
210+
* @param kr The micro-kernel tiling parameter for the K dimension.
211+
* @return The byte offset into the packed weights buffer.
212+
*/
213+
inline size_t packed_weights_offset(
214+
int n_idx,
215+
int k,
216+
int weight_nbit,
217+
int scale_group_size,
218+
bool has_scales,
219+
bool has_bias,
220+
int nr,
221+
int kr,
222+
int sr) {
223+
(void)sr; // unused
224+
assert(n_idx % nr == 0);
225+
226+
const size_t packed_tile_size_for_nr_cols = packed_weights_size(
227+
/*n=*/nr, // The size we are calculating is for a single tile of width
228+
// `nr`.
229+
k,
230+
weight_nbit,
231+
scale_group_size,
232+
has_scales,
233+
has_bias,
234+
nr,
235+
kr,
236+
sr);
237+
238+
return (n_idx / nr) * packed_tile_size_for_nr_cols;
239+
}
184240
} // namespace
185241
// torchao::kernels::cpu::aarch64::linear::groupwise_lowbit_weight_lut
186242

0 commit comments

Comments (0)