14 | 14 | dtype=torch.float16 |
15 | 15 |
16 | 16 |
| 17 | + |
| 18 | + |
| 19 | + |
| 20 | + |
| 21 | + |
| 22 | + |
17 | 23 | def SparsifyIndices( |
18 | 24 | x: torch.Tensor, ws: List[int], rs: List[int], head_idx: int |
19 | 25 | ) -> Tuple[int, torch.Tensor, Optional[torch.Tensor]]: |
@@ -104,6 +110,45 @@ def MixOutputs( |
104 | 110 |
105 | 111 |
106 | 112 |
| 113 | + |
| 114 | +class ParallelWrapper: |
| 115 | + """ |
| 116 | + A simple wrapper to enable easy usage of data parallelism. |
| 117 | +
| 118 | + Arguments: |
| 119 | + model: The neural network model to be parallelized. |
| 120 | + device (optional): The device to which the model should be moved. Default: "cuda". |
| 122 | + use_data_parallel (optional): Whether to wrap the model in nn.DataParallel when more than one GPU is available. Default: True. |
| 122 | + """ |
| 123 | + def __init__( |
| 124 | + self, |
| 125 | + model, |
| 126 | + device="cuda", |
| 127 | + use_data_parallel=True |
| 128 | + ): |
| 129 | + self.model = model.to(device) |
| 130 | + self.use_data_parallel = use_data_parallel |
| 131 | + self.device = device |
| 132 | + |
| 133 | + if self.use_data_parallel and torch.cuda.device_count() > 1: |
| 134 | + print(f"Using {torch.cuda.device_count()} GPUs") |
| 135 | + self.model = nn.DataParallel(self.model) |
| 136 | + |
| 137 | + def forward(self, *args, **kwargs): |
| 138 | + return self.model(*args, **kwargs) |
| 139 | + |
| 140 | + def to(self, device): |
| 141 | + self.device = device |
| 142 | + self.model = self.model.to(device) |
| 143 | + return self |
| 144 | + |
| 145 | + def __getattr__(self, name): |
| 146 | + # Redirect attribute access to the internal model to allow direct access to its methods and properties |
| 147 | + return getattr(self.model, name) |
| 148 | + |
| 149 | + |
| 150 | + |
| 151 | + |
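A minimal usage sketch for the `ParallelWrapper` added above; the toy model, input shape, and device selection here are illustrative assumptions, not part of this commit:

```python
import torch
import torch.nn as nn

# Hypothetical toy model; any nn.Module can be wrapped the same way.
device = "cuda" if torch.cuda.is_available() else "cpu"
model = nn.Linear(512, 512)

# Moves the model to `device` and, when use_data_parallel is set and more
# than one GPU is visible, wraps it in nn.DataParallel.
wrapped = ParallelWrapper(model, device=device, use_data_parallel=True)

x = torch.randn(8, 512, device=device)
# Call .forward(...) explicitly: the wrapper does not define __call__,
# so wrapped(x) would raise a TypeError.
out = wrapped.forward(x)
print(out.shape)  # torch.Size([8, 512])
```

Attribute access on the wrapper falls through to the underlying model via `__getattr__` (e.g. `wrapped.weight` when the model has not been wrapped in nn.DataParallel).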
107 | 152 | # TODO: add alibi, qk layer norm, one write head, multiway, |
108 | 153 | class DilatedAttentionNew(nn.Module): |
109 | 154 | """ |
@@ -319,6 +364,22 @@ def forward(self, x): |
319 | 364 |
320 | 365 |
321 | 366 |
| 367 | + |
| 368 | + |
| 369 | + |
| 370 | + |
| 371 | + |
| 372 | + |
| 373 | + |
| 374 | + |
| 375 | + |
| 376 | + |
| 377 | + |
| 378 | + |
| 379 | + |
| 380 | + |
| 381 | + |
| 382 | + |
322 | 383 |
323 | 384 |
324 | 385 | class MultiHeadDilatedAttention: |