@@ -1564,13 +1564,64 @@ def random_orthogonal_matrix(size):
15641564 return q
15651565
15661566
def _compute_hidden_dim(
        region: Region,
        block_rotation_dim: Optional[int] = None,
        insert_rotation_module: bool = False,
        disable_block_rotation_for_fused: bool = False,
        expansion_step: int = 1) -> int:
    """
    Return the rotation hidden dimension to use for a single region.

    The value is derived in three stages, in this order: start from
    ``region.max_shape_sinks``; if ``region.expand_region`` is set, replace it
    with the result of ``find_closest_hadamard_number``; finally, when block
    rotation is requested and allowed for this region, shrink it to
    ``block_rotation_dim`` provided the current dimension is a whole multiple
    (at least 2x) of the block size.

    Args:
        region: Region whose ``max_shape_sinks`` and ``expand_region`` are read.
        block_rotation_dim: Block size for block-wise rotation, or None to
            always use the full-vector dimension.
        insert_rotation_module: True when the region is an orphan sink
            (online rotation).
        disable_block_rotation_for_fused: When True, block rotation is skipped
            for regions where ``insert_rotation_module`` is False (fused
            rotations), so only online rotations are block-wise.
        expansion_step: Forwarded as ``steps`` to
            ``find_closest_hadamard_number`` when the region is expanded.

    Returns:
        ``block_rotation_dim`` when block rotation applies and is compatible,
        otherwise the (possibly expanded) full dimension.
    """
    # Stages 1-2: base dimension from the sinks, expanded only when required.
    hidden_dim = region.max_shape_sinks
    if region.expand_region:
        hidden_dim = find_closest_hadamard_number(hidden_dim, steps=expansion_step)

    # Stage 3: decide whether block rotation is in play at all.
    if block_rotation_dim is None:
        return hidden_dim

    # Fused rotations (no inserted rotation module) can opt out of block
    # rotation, leaving it enabled for online rotations only.
    if disable_block_rotation_for_fused and not insert_rotation_module:
        return hidden_dim

    # The block size must tile the current dimension more than once, exactly.
    tiles_evenly = (hidden_dim // block_rotation_dim > 1) and (
        hidden_dim % block_rotation_dim == 0)
    if tiles_evenly:
        return block_rotation_dim

    logging.info(
        f"Block rotation shape is not compatible with hidden_dim={ hidden_dim } ."
        " Falling back to full-vector rotation.")
    return hidden_dim
1615+
1616+
15671617def _compute_rotations (
15681618 model : nn .Module ,
15691619 regions : List [Region ],
15701620 full_rotation_method = 'had' ,
15711621 fuse_rotations : bool = True ,
15721622 expansion_step : int = 1 ,
1573- block_rotation_dim : Optional [int ] = None ):
1623+ block_rotation_dim : Optional [int ] = None ,
1624+ disable_block_rotation_for_fused : bool = False ):
15741625
15751626 rewriters = []
15761627 # First, rotations on orphan sinks are applied so the order in which rotations are
@@ -1589,9 +1640,22 @@ def _compute_rotations(
15891640 assert not region .expand_region , "Orthogonal rotation not compatible with expansion"
15901641 assert block_rotation_dim is None , "Orthogonal rotation not compatible with blockwise rotation"
15911642
1592- # Initialize variables
1593- hidden_dim = region .max_shape_sinks
1594- expanded_hidden_dim , expanded_rot_mat , expanded_K = None , None , None
1643+ # Compute hidden_dim per region (includes expansion if applicable)
1644+ hidden_dim = _compute_hidden_dim (
1645+ region = region ,
1646+ block_rotation_dim = block_rotation_dim ,
1647+ insert_rotation_module = insert_rotation_module ,
1648+ disable_block_rotation_for_fused = disable_block_rotation_for_fused ,
1649+ expansion_step = expansion_step )
1650+
1651+ # NOTE: We need to compute expanded_hidden_dim separately for weight padding in expansion
1652+ # regions. This is required for proper interop of block rotations and expansion: the weight
1653+ # must be padded to the expanded dimension before block reduction, but hidden_dim may have
1654+ # been reduced by block rotation for the parametrizations and rotation matrices.
1655+ if region .expand_region :
1656+ expanded_hidden_dim = find_closest_hadamard_number (
1657+ region .max_shape_sinks , steps = expansion_step )
1658+
15951659 if not insert_rotation_module and full_rotation_method == 'ort' :
15961660 rot_mat = random_orthogonal_matrix (hidden_dim )
15971661 rot_func = _apply_ort_device
@@ -1606,8 +1670,6 @@ def _compute_rotations(
16061670 try :
16071671 # Build hadamard rotation matrix
16081672 rot_mat , K = get_hadK (hidden_dim )
1609- expanded_hidden_dim = find_closest_hadamard_number (hidden_dim , steps = expansion_step )
1610- expanded_rot_mat , expanded_K = get_hadK (int (expanded_hidden_dim ))
16111673 rot_func = _apply_had_device
16121674 except AssertionError as e :
16131675 logging .info (f"Incompatible dim { hidden_dim } for hadamard rotation" )
@@ -1619,22 +1681,6 @@ def _compute_rotations(
16191681 logging .info ("Skipping region" )
16201682 continue
16211683
1622- hidden_dim = hidden_dim if not region .expand_region else expanded_hidden_dim
1623- # Check if we are doing block_rotation and if it is compatible with the current shape
1624- if block_rotation_dim is not None :
1625- if hidden_dim // block_rotation_dim > 1 and hidden_dim % block_rotation_dim == 0 :
1626- hidden_dim = block_rotation_dim
1627- rot_mat , K = get_hadK (block_rotation_dim )
1628- if region .expand_region :
1629- expanded_rot_mat , expanded_K = rot_mat , K
1630- else :
1631- logging .info (
1632- "Block rotation shape is not compatible with the region shape, perfoming normal rotations"
1633- )
1634-
1635- if region .expand_region :
1636- rot_mat , K = expanded_rot_mat , expanded_K
1637-
16381684 # Cast rotation matrix to the weight dtype
16391685 if rot_mat is not None :
16401686 dtype = next (model .parameters ()).dtype
@@ -1896,6 +1942,7 @@ def __init__(
18961942 use_parametrized_rotations : bool = False ,
18971943 full_rotation_method : str = 'had' ,
18981944 block_rotation_dim : Optional [int ] = None ,
1945+ disable_block_rotation_for_fused : bool = False ,
18991946 layers_to_expand : Optional [List [str ]] = None ,
19001947 expansion_step : int = None ,
19011948 delay_rewriters : bool = False ,
@@ -1921,6 +1968,7 @@ def __init__(
19211968 self .expansion_step = expansion_step
19221969 self .delay_rewriters = delay_rewriters
19231970 self .block_rotation_dim = block_rotation_dim
1971+ self .disable_block_rotation_for_fused = disable_block_rotation_for_fused
19241972
19251973 if self .delay_rewriters :
19261974 assert return_rewriters , "If `delay_rewriters=True`, rewriters are not applied immediately. Therefore, these must be returned, by setting `return_rewriters=True`, to be applied at a later stage."
@@ -2098,15 +2146,17 @@ def apply(self,
20982146 self .full_rotation_method ,
20992147 fuse_rotations = not self .use_parametrized_rotations ,
21002148 expansion_step = first_exp_step ,
2101- block_rotation_dim = self .block_rotation_dim ))
2149+ block_rotation_dim = self .block_rotation_dim ,
2150+ disable_block_rotation_for_fused = self .disable_block_rotation_for_fused ))
21022151 rewriters .extend (
21032152 _compute_rotations (
21042153 graph_model ,
21052154 second_set ,
21062155 self .full_rotation_method ,
21072156 fuse_rotations = not self .use_parametrized_rotations ,
21082157 expansion_step = second_exp_step ,
2109- block_rotation_dim = self .block_rotation_dim ))
2158+ block_rotation_dim = self .block_rotation_dim ,
2159+ disable_block_rotation_for_fused = self .disable_block_rotation_for_fused ))
21102160 if len (expanded_regions ) > 0 :
21112161 parameter_number_post = 0
21122162 for m in graph_model .parameters ():
0 commit comments