|
| 1 | +# Copyright (c) Meta Platforms, Inc. and affiliates. |
| 2 | +# All rights reserved. |
| 3 | +# |
| 4 | +# This source code is licensed under the BSD-style license found in the |
| 5 | +# LICENSE file in the root directory of this source tree. |
| 6 | + |
| 7 | +import torch |
| 8 | + |
| 9 | + |
class SplitLinearModule(torch.nn.Module):
    """Bias-free linear layer built from a grid of smaller ``torch.nn.Linear`` ops.

    The ``(out_features, in_features)`` weight matrix is cut into equally
    sized tiles: ``out_features`` is divided into ``len(self.out_split_sizes)``
    pieces and ``in_features`` into ``len(self.in_split_sizes)`` pieces. Each
    tile is held by its own ``torch.nn.Linear(..., bias=False)`` in
    ``self.ops``. ``forward`` evaluates every tile and reassembles the result
    along the last dimension.

    Args:
        in_features: Size of the last dimension of the input.
        out_features: Size of the last dimension of the output.
        out_target_split_size: Desired size of each output piece.
        out_max_splits: Upper bound on the number of output pieces.
        in_target_split_size: Desired size of each input piece.
        in_max_splits: Upper bound on the number of input pieces.

    Raises:
        ValueError: If a target split size or max split count is smaller than
            1, or if the features cannot be divided evenly into the chosen
            number of pieces.
    """

    def __init__(
        self,
        in_features,
        out_features,
        out_target_split_size=1,
        out_max_splits=1,
        in_target_split_size=1,
        in_max_splits=1,
    ):
        super().__init__()
        self.out_split_sizes = self._get_split_sizes(
            out_features, out_target_split_size, out_max_splits
        )
        self.in_split_sizes = self._get_split_sizes(
            in_features, in_target_split_size, in_max_splits
        )
        print(
            f"Splitting out_features={out_features} into {len(self.out_split_sizes)} of size {self.out_split_sizes[0]}."
        )
        print(
            f"Splitting in_features={in_features} into {len(self.in_split_sizes)} of size {self.in_split_sizes[0]}."
        )

        # self.ops contains a list of linear ops for different pieces of the output matrix
        # The index of an op at (in_idx, out_idx) is given by self.op_index(in_idx, out_idx)
        self.ops = torch.nn.ModuleList()
        for idx_out, s_out in enumerate(self.out_split_sizes):
            for idx_in, s_in in enumerate(self.in_split_sizes):
                # Internal invariant: ops are appended in exactly the order
                # that op_index() later uses to look them up.
                assert len(self.ops) == self.op_index(idx_in, idx_out)
                self.ops.append(torch.nn.Linear(s_in, s_out, bias=False))

    def op_index(self, in_index, out_index):
        """Return the position in ``self.ops`` of the tile at (in_index, out_index).

        Tiles are stored with the output index as the slow axis and the
        input index as the fast axis.
        """
        return out_index * len(self.in_split_sizes) + in_index

    def _get_split_sizes(self, n_features, target_split_size, max_splits):
        """Return a list of equal piece sizes that add up to ``n_features``.

        The number of pieces is ``n_features // target_split_size``, forced to
        be at least 1 and at most ``max_splits``.

        Raises:
            ValueError: If ``target_split_size`` or ``max_splits`` is smaller
                than 1, or if ``n_features`` is not divisible by the chosen
                number of pieces.
        """
        # Reject values that would otherwise surface as a ZeroDivisionError
        # here or as an IndexError on an empty split list in __init__.
        if target_split_size < 1:
            raise ValueError(
                f"target_split_size must be at least 1, got {target_split_size}."
            )
        if max_splits < 1:
            raise ValueError(f"max_splits must be at least 1, got {max_splits}.")

        num_splits = min(max(n_features // target_split_size, 1), max_splits)
        split_size, split_remainder = divmod(n_features, num_splits)
        if split_remainder > 0:
            raise ValueError(
                f"Cannot split {n_features} with target_split_size={target_split_size} and max_splits={max_splits} because it leaves a remainder of {split_remainder}."
            )

        return [split_size] * num_splits

    def set_params(self, weight):
        """Load a full weight matrix into the tiled sub-ops.

        ``weight`` is split along dim 0 by ``self.out_split_sizes`` and then
        along dim 1 by ``self.in_split_sizes``; each resulting tile is wrapped
        in a new ``torch.nn.Parameter`` and assigned to the matching op. The
        tiles come from ``Tensor.split``, so no explicit copy of the data is
        made here.

        Args:
            weight: Tensor whose first two dimensions match the sums of the
                output and input split sizes respectively.
        """
        # Same traversal order as the construction loop in __init__:
        # output piece is the outer loop, input piece the inner loop.
        tiles = [
            tile
            for out_rows in weight.split(self.out_split_sizes, dim=0)
            for tile in out_rows.split(self.in_split_sizes, dim=1)
        ]
        for op, tile in zip(self.ops, tiles):
            op.weight = torch.nn.Parameter(tile)

    def forward(self, x):
        """Apply the tiled linear map to ``x`` along its last dimension.

        Args:
            x: Tensor that is split along ``dim=-1`` by ``self.in_split_sizes``
                when there is more than one input piece.

        Returns:
            The per-output-piece results concatenated along ``dim=-1``.
        """
        if len(self.in_split_sizes) == 1:
            # One input piece: there is exactly one op per output piece.
            out_chunks = [op(x) for op in self.ops]
        else:
            x_splits = x.split(self.in_split_sizes, dim=-1)
            out_chunks = []
            for out_idx in range(len(self.out_split_sizes)):
                # Invoke the sub-modules through __call__ (not .forward) so
                # that registered hooks run, as they do in the branch above.
                partials = [
                    self.ops[self.op_index(in_idx, out_idx)](x_piece)
                    for in_idx, x_piece in enumerate(x_splits)
                ]
                # Partial products over the input pieces add up to the full
                # matrix product for this output piece.
                out_chunks.append(torch.sum(torch.stack(partials), dim=0))

        return torch.concat(out_chunks, dim=-1)
| 91 | + |
| 92 | + |
def replace_linear_with_split_linear(
    model, out_target_split_size, out_max_splits, in_target_split_size, in_max_splits=1
):
    """Recursively swap every ``torch.nn.Linear`` in ``model`` for a ``SplitLinearModule``.

    The module tree is modified in place. Each direct child that is a
    ``torch.nn.Linear`` is replaced by a ``SplitLinearModule`` built from its
    ``in_features`` / ``out_features`` and loaded with its weight; every other
    child is descended into with the same split settings.

    Args:
        model: Module whose children are inspected and replaced in place.
        out_target_split_size: Forwarded to ``SplitLinearModule``.
        out_max_splits: Forwarded to ``SplitLinearModule``.
        in_target_split_size: Forwarded to ``SplitLinearModule``.
        in_max_splits: Forwarded to ``SplitLinearModule``.

    Raises:
        AssertionError: If a ``torch.nn.Linear`` with a bias is encountered.
    """
    # Iterate over a snapshot: setattr below rebinds children while we walk.
    for name, module in list(model.named_children()):
        if not isinstance(module, torch.nn.Linear):
            replace_linear_with_split_linear(
                module,
                out_target_split_size,
                out_max_splits,
                in_target_split_size,
                in_max_splits,
            )
            continue

        # Raised explicitly rather than via `assert` so the check still runs
        # under `python -O`; otherwise the bias would be dropped silently.
        # The exception type is kept as AssertionError for existing callers.
        if module.bias is not None:
            raise AssertionError("SplitLinearModule does not support bias")

        new_module = SplitLinearModule(
            module.in_features,
            module.out_features,
            out_target_split_size,
            out_max_splits,
            in_target_split_size,
            in_max_splits,
        )
        new_module.set_params(module.weight)
        setattr(model, name, new_module)
0 commit comments