Skip to content

Commit 657482e

Browse files
committed
reorganize sparsity module to separate weight and attention sparsity
1 parent 1537885 commit 657482e

File tree

10 files changed

+1115
-0
lines changed

10 files changed

+1115
-0
lines changed
Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
# SPDX-License-Identifier: Apache-2.0
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
16+
"""API for weight sparsification algorithms."""
17+
18+
from . import mode, module, plugins
19+
20+
# Explicitly expose commonly used items
21+
from .mode import SparsityModeRegistry
22+
from .module import SparseModule, SpDMRegistry
23+
from .sparsification import *
Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
# SPDX-License-Identifier: Apache-2.0
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
16+
"""Default configurations for sparsity modes."""
17+
18+
from pydantic import create_model
19+
20+
from modelopt.torch.opt.config import ModeloptBaseConfig, get_kwargs_for_create_model_with_rules
21+
22+
from .module import SpDMRegistry
23+
24+
SparseMagnitudeConfig = create_model(
25+
"SparseMagnitudeConfig",
26+
**get_kwargs_for_create_model_with_rules(
27+
registry=SpDMRegistry,
28+
default_rules={
29+
"nn.Linear": {"*": {}, "*lm_head*": None},
30+
"nn.Conv2d": {"*": {}, "*lm_head*": None},
31+
},
32+
doc='Configuration for the ``"sparse_magnitude"`` mode.',
33+
),
34+
)
35+
36+
37+
SparseGPTConfig = create_model(
38+
"SparseGPTConfig",
39+
**get_kwargs_for_create_model_with_rules(
40+
registry=SpDMRegistry,
41+
default_rules={
42+
"nn.Linear": {"*": {}, "*lm_head*": None},
43+
"nn.Conv2d": {"*": {}, "*lm_head*": None},
44+
},
45+
doc='Configuration for the ``"sparse_gpt"`` mode.',
46+
),
47+
)
48+
49+
50+
class ExportSparseConfig(ModeloptBaseConfig):
51+
"""Configuration (empty!) for the ``"export_sparse"`` mode."""
Lines changed: 148 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,148 @@
1+
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
# SPDX-License-Identifier: Apache-2.0
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
16+
"""Magnitude-base sparsity inspired by NVIDIA ASP (Automatic SParsity)."""
17+
18+
import re
19+
import warnings
20+
from itertools import permutations
21+
22+
import torch
23+
import torch.nn as nn
24+
25+
from .module import SparseModule
26+
from .searcher import BaseSparseSearcher
27+
28+
29+
def get_nmprune_info(pattern: str) -> tuple[bool, int, int]:
30+
"""Gets the n:m sparsity pattern information from a given string."""
31+
nm_prune = re.search(r"(\d+):(\d+) sparsity", pattern)
32+
if nm_prune is not None:
33+
n, m = map(int, nm_prune.groups())
34+
return nm_prune is not None, n, m
35+
return False, 0, 0
36+
37+
38+
def fill(x):
39+
"""Calculates the ratio of non-zero elements in a tensor."""
40+
return float(x.nonzero().size(0)) / torch.numel(x)
41+
42+
43+
def reshape_1d(matrix, m):
44+
"""Reshapes a given matrix into m-dimensional vectors: (h,w) -> (hw/m, m)."""
45+
if matrix.shape[1] % m > 0:
46+
new_cols = matrix.shape[1] + (m - matrix.shape[1] % m)
47+
mat = matrix.new_empty(matrix.shape[0], new_cols).fill_(0)
48+
mat[:, : matrix.shape[1]] = matrix
49+
50+
return mat.view(-1, m), mat.shape
51+
else:
52+
return matrix.view(-1, m), matrix.shape
53+
54+
55+
def compute_valid_1d_patterns(m, n):
56+
"""Computes all possible m:n patterns in a 1D vector.
57+
58+
The function generates a tensor of size m with n ones and (m-n) zeros.
59+
It then generates all permutations of this tensor, removes duplicates,
60+
and returns the unique patterns as a tensor.
61+
"""
62+
patterns = torch.zeros(m)
63+
patterns[:n] = 1
64+
valid_patterns = torch.tensor(list(set(permutations(patterns.tolist()))))
65+
return valid_patterns
66+
67+
68+
def mn_1d_best(matrix, m, n):
69+
"""Finds the best m:n pattern in a given matrix.
70+
71+
The function computes all possible m:n patterns and selects the one
72+
that maximizes the sum of non-masked weights in the matrix. The selected
73+
pattern is then used to create a mask for the matrix.
74+
"""
75+
patterns = compute_valid_1d_patterns(m, n).to(matrix.device)
76+
77+
# Find the best m:n pattern (sum of non-masked weights).
78+
mask = torch.IntTensor(matrix.shape).fill_(1).view(-1, m)
79+
mat, _ = reshape_1d(matrix, m)
80+
pmax = torch.argmax(torch.matmul(mat.abs(), patterns.t()), dim=1)
81+
mask[:] = patterns[pmax[:]]
82+
mask = mask.view(matrix.shape)
83+
return mask
84+
85+
86+
def m4n2_1d(mat):
87+
"""Finds the best 2:4 pattern in a given matrix."""
88+
return mn_1d_best(mat, 4, 2)
89+
90+
91+
def create_asp_mask(tensor: nn.Parameter, pattern: str) -> torch.BoolTensor:
92+
"""Creates a mask for a given tensor based on a specified sparse pattern.
93+
94+
The function reshapes the tensor and applies the specified pattern to create a sparse mask.
95+
The default pattern is m4n2_1d, which finds the best 2:4 sparsity pattern in the tensor.
96+
"""
97+
pattern_method_lut = {BaseSparseSearcher._pattern_2_4: m4n2_1d}
98+
if pattern not in pattern_method_lut:
99+
raise NotImplementedError(f"Unsupported pattern {pattern} for ASP sparsity")
100+
func = pattern_method_lut[pattern]
101+
102+
shape = tensor.shape
103+
tensor.type()
104+
t = tensor.float().contiguous()
105+
106+
# 1d-tensor
107+
if len(shape) == 1:
108+
t = t.view(1, shape[0])
109+
mask = func(t)
110+
# 2d-tensor (K, C)
111+
elif len(shape) == 2:
112+
# linear
113+
t = t.view(shape[0], shape[1])
114+
mask = func(t)
115+
# 3d-tensor (K, C, R)
116+
elif len(shape) == 3:
117+
# 1d convs
118+
t = t.permute(0, 2, 1).contiguous().view(shape[0] * shape[2], shape[1])
119+
mask = func(t)
120+
mask = mask.view(shape[0], shape[2], shape[1]).permute(0, 2, 1).contiguous()
121+
# 4d-tensor (K, C, R, S)
122+
elif len(shape) == 4:
123+
# 2d convs
124+
t = t.permute(2, 3, 0, 1).contiguous().view(shape[2] * shape[3] * shape[0], shape[1])
125+
mask = func(t)
126+
mask = mask.view(shape[2], shape[3], shape[0], shape[1]).permute(2, 3, 0, 1).contiguous()
127+
128+
return mask.view(shape).to(dtype=torch.bool)
129+
130+
131+
class MagnitudeSearcher(BaseSparseSearcher):
132+
"""Searcher for magnitude-based sparsity."""
133+
134+
def _check_weight_size(self, weight: torch.nn.Parameter, mod_name: str) -> bool:
135+
"""Check if the weight size is supported."""
136+
# rules from ASP
137+
if weight.size(0) % 8 != 0 or weight.size(1) % 16 != 0:
138+
warnings.warn(
139+
f"Skipping sparsifying {mod_name} of size={weight.size()!s} and"
140+
f" type={weight.dtype!s} for sparsity"
141+
)
142+
return False
143+
144+
return True
145+
146+
def _compute_mask(self, module: SparseModule) -> torch.BoolTensor:
147+
"""Compute the mask (and weight update) for the given module."""
148+
return create_asp_mask(module.weight, self.config["pattern"])

0 commit comments

Comments
 (0)