Skip to content

Commit d7f847c

Browse files
RachelXu7minghaoBD
andauthored
speedup sparse block (PaddlePaddle#1022)
* speedup_sparse_block * speedup_sparse_block * speedup_sparse_block Co-authored-by: minghaoBD <[email protected]>
1 parent 46751fe commit d7f847c

File tree

1 file changed

+20
-6
lines changed

1 file changed

+20
-6
lines changed

paddleslim/prune/unstructured_pruner_utils.py

Lines changed: 20 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -8,11 +8,25 @@
88

99
def cal_mxn_avg_matrix(mat, m=1, n=1):
1010
if m == 1 and n == 1: return copy.deepcopy(mat)
11+
12+
ori_row, ori_col = mat.shape[0], mat.shape[1]
13+
if len(mat.shape) == 4:
14+
assert mat.shape[2:] == (1, 1), "Only support for (n, n, 1, 1) for now."
15+
mat = mat.reshape(ori_row, ori_col)
16+
17+
res_col = n - len(mat[0]) % n
18+
res_row = m - len(mat) % m
19+
20+
mat = np.pad(mat, ((0, res_col), (0, res_col)), 'reflect')
1121
avg_mat = np.zeros_like(mat)
12-
rows = len(mat) // m + 1
13-
cols = len(mat[0]) // n + 1
14-
for row in range(rows):
15-
for col in range(cols):
16-
avg_mat[m * row:m * row + m, n * col:n * col + n] = np.mean(mat[
17-
m * row:m * row + m, n * col:n * col + n])
22+
new_shape = [len(mat) // m, len(mat[0]) // n, m, n]
23+
strides = mat.itemsize * np.array([len(mat) * m, n, len(mat), 1])
24+
mat = np.lib.stride_tricks.as_strided(mat, shape=new_shape, strides=strides)
25+
mat = mat.mean((2, 3), keepdims=True)
26+
mat = np.tile(mat, (1, 1, m, n))
27+
for i in range(len(mat)):
28+
avg_mat[i * m:i * m + m] = np.concatenate(list(mat[i]), axis=1)
29+
avg_mat = avg_mat[:ori_row, :ori_col]
30+
if len(mat.shape) == 4:
31+
avg_mat = avg_mat.reshape(ori_row, ori_col, 1, 1)
1832
return avg_mat

0 commit comments

Comments
 (0)