Skip to content

Commit e60d648

Browse files
authored
Feat (equalize): option to disable fused block rotations (#1438)
1 parent dbb6e37 commit e60d648

File tree

4 files changed

+89
-28
lines changed

4 files changed

+89
-28
lines changed

src/brevitas/graph/equalize.py

Lines changed: 74 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1564,13 +1564,64 @@ def random_orthogonal_matrix(size):
15641564
return q
15651565

15661566

1567+
def _compute_hidden_dim(
1568+
region: Region,
1569+
block_rotation_dim: Optional[int] = None,
1570+
insert_rotation_module: bool = False,
1571+
disable_block_rotation_for_fused: bool = False,
1572+
expansion_step: int = 1) -> int:
1573+
"""
1574+
Compute the hidden dimension for rotation per region.
1575+
1576+
Since each region may have a different shape and block rotation compatibility,
1577+
this calculation must be performed on a per-region basis. The order of operations
1578+
is: (1) initialize from region, (2) expand if needed, (3) apply block rotation.
1579+
1580+
Args:
1581+
region: The region for which to compute hidden dimension.
1582+
block_rotation_dim: Optional block rotation dimension for block-wise rotations.
1583+
insert_rotation_module: Whether this region is an orphan sink (online rotation).
1584+
disable_block_rotation_for_fused: Whether to disable block rotation for fused rotations.
1585+
expansion_step: Number of steps for finding closest hadamard number (used for expansion).
1586+
1587+
Returns:
1588+
The computed hidden dimension for the region.
1589+
"""
1590+
# Step 1: Initialize hidden_dim to the max shape of the sinks
1591+
hidden_dim = region.max_shape_sinks
1592+
1593+
# Step 2: Expand if region requires expansion
1594+
if region.expand_region:
1595+
hidden_dim = find_closest_hadamard_number(hidden_dim, steps=expansion_step)
1596+
1597+
# Step 3: Apply block rotation if applicable
1598+
apply_block_rotation = block_rotation_dim is not None
1599+
1600+
# insert_rotation_module is True for orphan sinks (aka online rotations). Sometimes we want to use
1601+
# block rotations only for online rotations and full-vector for fused rotations.
1602+
if not insert_rotation_module and disable_block_rotation_for_fused:
1603+
apply_block_rotation = False
1604+
1605+
if apply_block_rotation:
1606+
# Check block_rotation is compatible with the current shape
1607+
if (hidden_dim // block_rotation_dim > 1) and (hidden_dim % block_rotation_dim == 0):
1608+
hidden_dim = block_rotation_dim
1609+
else:
1610+
logging.info(
1611+
f"Block rotation shape is not compatible with hidden_dim={hidden_dim}."
1612+
" Falling back to full-vector rotation.")
1613+
1614+
return hidden_dim
1615+
1616+
15671617
def _compute_rotations(
15681618
model: nn.Module,
15691619
regions: List[Region],
15701620
full_rotation_method='had',
15711621
fuse_rotations: bool = True,
15721622
expansion_step: int = 1,
1573-
block_rotation_dim: Optional[int] = None):
1623+
block_rotation_dim: Optional[int] = None,
1624+
disable_block_rotation_for_fused: bool = False):
15741625

15751626
rewriters = []
15761627
# First, rotations on orphan sinks are applied so the order in which rotations are
@@ -1589,9 +1640,22 @@ def _compute_rotations(
15891640
assert not region.expand_region, "Orthogonal rotation not compatible with expansion"
15901641
assert block_rotation_dim is None, "Orthogonal rotation not compatible with blockwise rotation"
15911642

1592-
# Initialize variables
1593-
hidden_dim = region.max_shape_sinks
1594-
expanded_hidden_dim, expanded_rot_mat, expanded_K = None, None, None
1643+
# Compute hidden_dim per region (includes expansion if applicable)
1644+
hidden_dim = _compute_hidden_dim(
1645+
region=region,
1646+
block_rotation_dim=block_rotation_dim,
1647+
insert_rotation_module=insert_rotation_module,
1648+
disable_block_rotation_for_fused=disable_block_rotation_for_fused,
1649+
expansion_step=expansion_step)
1650+
1651+
# NOTE: We need to compute expanded_hidden_dim separately for weight padding in expansion
1652+
# regions. This is required for proper interop of block rotations and expansion: the weight
1653+
# must be padded to the expanded dimension before block reduction, but hidden_dim may have
1654+
# been reduced by block rotation for the parametrizations and rotation matrices.
1655+
if region.expand_region:
1656+
expanded_hidden_dim = find_closest_hadamard_number(
1657+
region.max_shape_sinks, steps=expansion_step)
1658+
15951659
if not insert_rotation_module and full_rotation_method == 'ort':
15961660
rot_mat = random_orthogonal_matrix(hidden_dim)
15971661
rot_func = _apply_ort_device
@@ -1606,8 +1670,6 @@ def _compute_rotations(
16061670
try:
16071671
# Build hadamard rotation matrix
16081672
rot_mat, K = get_hadK(hidden_dim)
1609-
expanded_hidden_dim = find_closest_hadamard_number(hidden_dim, steps=expansion_step)
1610-
expanded_rot_mat, expanded_K = get_hadK(int(expanded_hidden_dim))
16111673
rot_func = _apply_had_device
16121674
except AssertionError as e:
16131675
logging.info(f"Incompatible dim {hidden_dim} for hadamard rotation")
@@ -1619,22 +1681,6 @@ def _compute_rotations(
16191681
logging.info("Skipping region")
16201682
continue
16211683

1622-
hidden_dim = hidden_dim if not region.expand_region else expanded_hidden_dim
1623-
# Check if we are doing block_rotation and if it is compatible with the current shape
1624-
if block_rotation_dim is not None:
1625-
if hidden_dim // block_rotation_dim > 1 and hidden_dim % block_rotation_dim == 0:
1626-
hidden_dim = block_rotation_dim
1627-
rot_mat, K = get_hadK(block_rotation_dim)
1628-
if region.expand_region:
1629-
expanded_rot_mat, expanded_K = rot_mat, K
1630-
else:
1631-
logging.info(
1632-
"Block rotation shape is not compatible with the region shape, perfoming normal rotations"
1633-
)
1634-
1635-
if region.expand_region:
1636-
rot_mat, K = expanded_rot_mat, expanded_K
1637-
16381684
# Cast rotation matrix to the weight dtype
16391685
if rot_mat is not None:
16401686
dtype = next(model.parameters()).dtype
@@ -1896,6 +1942,7 @@ def __init__(
18961942
use_parametrized_rotations: bool = False,
18971943
full_rotation_method: str = 'had',
18981944
block_rotation_dim: Optional[int] = None,
1945+
disable_block_rotation_for_fused: bool = False,
18991946
layers_to_expand: Optional[List[str]] = None,
19001947
expansion_step: int = None,
19011948
delay_rewriters: bool = False,
@@ -1921,6 +1968,7 @@ def __init__(
19211968
self.expansion_step = expansion_step
19221969
self.delay_rewriters = delay_rewriters
19231970
self.block_rotation_dim = block_rotation_dim
1971+
self.disable_block_rotation_for_fused = disable_block_rotation_for_fused
19241972

19251973
if self.delay_rewriters:
19261974
assert return_rewriters, "If `delay_rewriters=True`, rewriters are not applied immediately. Therefore, these must be returned, by setting `return_rewriters=True`, to be applied at a later stage."
@@ -2098,15 +2146,17 @@ def apply(self,
20982146
self.full_rotation_method,
20992147
fuse_rotations=not self.use_parametrized_rotations,
21002148
expansion_step=first_exp_step,
2101-
block_rotation_dim=self.block_rotation_dim))
2149+
block_rotation_dim=self.block_rotation_dim,
2150+
disable_block_rotation_for_fused=self.disable_block_rotation_for_fused))
21022151
rewriters.extend(
21032152
_compute_rotations(
21042153
graph_model,
21052154
second_set,
21062155
self.full_rotation_method,
21072156
fuse_rotations=not self.use_parametrized_rotations,
21082157
expansion_step=second_exp_step,
2109-
block_rotation_dim=self.block_rotation_dim))
2158+
block_rotation_dim=self.block_rotation_dim,
2159+
disable_block_rotation_for_fused=self.disable_block_rotation_for_fused))
21102160
if len(expanded_regions) > 0:
21112161
parameter_number_post = 0
21122162
for m in graph_model.parameters():

src/brevitas_examples/llm/llm_args.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -365,6 +365,10 @@ def create_args_parser() -> ArgumentParser:
365365
type=int,
366366
default=None,
367367
help='Perform blockwise rotations when possible. Default: %(default)s')
368+
parser.add_argument(
369+
'--disable-block-rotation-for-fused',
370+
action='store_true',
371+
help='Disable block rotations when using fused rotations. Default: %(default)s')
368372
parser.add_argument('--svd-quant', action='store_true', help='Apply SVDQuant.')
369373
parser.add_argument(
370374
'--svd-quant-rank',

src/brevitas_examples/llm/main.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -130,6 +130,7 @@ def fused_rotation_no_fx(model, calibration_loader, args):
130130
expansion_step=args.expansion_step,
131131
layers_to_expand=layers_to_expand,
132132
block_rotation_dim=args.block_rotation_dim,
133+
disable_block_rotation_for_fused=args.disable_block_rotation_for_fused,
133134
extra_state_kwargs={'scale_invariant_layers': rmsnorm_classes})
134135
fx_model, rewriters = eq.apply(fx_model)
135136

@@ -347,6 +348,7 @@ def quantize_llm(args, extra_args=None):
347348
expansion_step=args.expansion_step,
348349
layers_to_expand=layers_to_expand,
349350
block_rotation_dim=args.block_rotation_dim,
351+
disable_block_rotation_for_fused=args.disable_block_rotation_for_fused,
350352
extra_state_kwargs={'scale_invariant_layers': rmsnorm_classes})
351353
model = eq.apply(model)
352354
remove_hooks(model)

tests/brevitas/graph/test_equalization.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -435,6 +435,7 @@ def compare_model_weights(model_fused, model_unfused, classes_to_compare=(nn.Lin
435435
@pytest_cases.parametrize('use_fx', [True, False], ids=["fx", "no-fx"])
436436
@pytest_cases.parametrize('expansion_step', [3, 0], ids=["expansion", "no-expansion"])
437437
@pytest_cases.parametrize('block_rotation_dim', [12, None])
438+
@pytest_cases.parametrize('disable_block_rotation_for_fused', [True, False])
438439
def test_compute_rotations(
439440
rotation_model,
440441
mask,
@@ -443,7 +444,8 @@ def test_compute_rotations(
443444
fuse_rotations,
444445
use_fx,
445446
expansion_step,
446-
block_rotation_dim):
447+
block_rotation_dim,
448+
disable_block_rotation_for_fused):
447449
if expansion_step > 0 and full_rotation_method == 'ort':
448450
pytest.skip("Expansion is not compatible with orthogonal rotations")
449451
if block_rotation_dim is not None and full_rotation_method == 'ort':
@@ -518,7 +520,8 @@ def patched_function(tensor, had_K, K):
518520
full_rotation_method=full_rotation_method,
519521
fuse_rotations=False,
520522
expansion_step=expansion_step,
521-
block_rotation_dim=block_rotation_dim)
523+
block_rotation_dim=block_rotation_dim,
524+
disable_block_rotation_for_fused=disable_block_rotation_for_fused)
522525
elif full_rotation_method == 'ort':
523526
with patch('brevitas.graph.equalize.random_orthogonal_matrix',
524527
partial(_random_orthogonal_matrix, generator=generator)):
@@ -528,7 +531,8 @@ def patched_function(tensor, had_K, K):
528531
full_rotation_method=full_rotation_method,
529532
fuse_rotations=False,
530533
expansion_step=expansion_step,
531-
block_rotation_dim=block_rotation_dim)
534+
block_rotation_dim=block_rotation_dim,
535+
disable_block_rotation_for_fused=disable_block_rotation_for_fused)
532536

533537
apply_rewriters(rotated_model_unfused, rewriters)
534538

@@ -545,7 +549,8 @@ def patched_function(tensor, had_K, K):
545549
full_rotation_method=full_rotation_method,
546550
fuse_rotations=True,
547551
expansion_step=expansion_step,
548-
block_rotation_dim=block_rotation_dim)
552+
block_rotation_dim=block_rotation_dim,
553+
disable_block_rotation_for_fused=disable_block_rotation_for_fused)
549554
apply_rewriters(rotated_model_fused, r)
550555

551556
# Compute outputs for each model

0 commit comments

Comments
 (0)