-
Notifications
You must be signed in to change notification settings - Fork 19
Expand file tree
/
Copy pathmodel.py
More file actions
747 lines (646 loc) · 26.8 KB
/
model.py
File metadata and controls
747 lines (646 loc) · 26.8 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
###############################################################################
# Adapted from https://github.com/ModelCloud/GPTQModel
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
###############################################################################
# Standard
from logging import getLogger
from typing import List, Optional, Union
import functools
import hashlib
import json
import logging
import os
# Third Party
from tqdm import tqdm
from transformers import AutoConfig, PretrainedConfig
from transformers.utils.hub import cached_file
import accelerate
import threadpoolctl as tctl
import torch
import torch.nn as nn
import transformers
# Local
from ..models._const import (
CPU,
CUDA_0,
EXLLAMA_DEFAULT_MAX_INPUT_LENGTH,
EXPERT_INDEX_PLACEHOLDER,
SUPPORTED_MODELS,
)
from ..nn_modules.qlinear import BaseQuantLinear
from ..quantization import FORMAT, QuantizeConfig
from .backend import Backend
from .importer import select_quant_linear
# Module-level logger used by every helper in this file. A StreamHandler with a
# "LEVEL - message" format is attached at import time and the level is forced
# to INFO. NOTE(review): propagation to the root logger is not disabled, so
# records may be printed twice if the application also configures root
# logging — confirm this is intended.
logger = getLogger(__name__)
handler = logging.StreamHandler()
formatter = logging.Formatter("%(levelname)s - %(message)s")
handler.setFormatter(formatter)
logger.addHandler(handler)
logger.setLevel(logging.INFO)
def recurse_getattr(obj, attr: str):
    """
    Resolve a dotted attribute path such as ``'attribute1.attribute2'``.

    Args:
        obj:
            The object the path starts from.
        attr (`str`):
            Dot-separated attribute names, applied left to right.

    Returns:
        The object found at the end of the path.

    Raises:
        AttributeError: if any component of the path is missing.
    """
    current = obj
    for part in attr.split("."):
        current = getattr(current, part)
    return current
def recurse_setattr(module, name, value):
    """Set the attribute at dotted path ``name`` (e.g. ``"a.b.c"``) on ``module`` to ``value``.

    Every component except the last is resolved with ``getattr``; the last one
    is assigned with ``setattr`` on the object reached.
    """
    *parents, leaf = name.split(".")
    target = module
    for part in parents:
        target = getattr(target, part)
    setattr(target, leaf, value)
def get_device(obj: Union[torch.Tensor, nn.Module]):
    """Return the device of a tensor, or of the first parameter of a module.

    Raises StopIteration for a module that has no parameters.
    """
    if not isinstance(obj, torch.Tensor):
        # Modules have no device of their own; use their first parameter's.
        obj = next(obj.parameters())
    return obj.device
def move_to(obj: Union[torch.Tensor, nn.Module], device: torch.device):
    """Return ``obj`` on ``device``, calling ``.to`` only when it is elsewhere.

    For a module the move is in place (``nn.Module.to`` returns ``self``); for
    a tensor a moved copy is returned.
    """
    return obj.to(device) if get_device(obj) != device else obj
def nested_move_to(v, device):
    """Move every tensor nested inside lists/tuples of ``v`` to ``device``.

    Tensors are moved with ``move_to``. Lists and tuples, including
    namedtuples, are rebuilt element-wise with the same container type. Any
    other value is returned unchanged.
    """
    if isinstance(v, torch.Tensor):
        return move_to(v, device)
    if isinstance(v, (list, tuple)):
        moved = [nested_move_to(e, device) for e in v]
        # namedtuple constructors take their fields positionally, not as a
        # single iterable, so `type(v)(moved)` would raise TypeError.
        if isinstance(v, tuple) and hasattr(v, "_fields"):
            return type(v)(*moved)
        return type(v)(moved)
    return v
def find_layers(module, layers=None, name=""):
    """Collect the linear-like layers inside ``module``.

    Args:
        module: root module to search.
        layers: layer classes to match; when empty/None defaults to
            ``[transformers.pytorch_utils.Conv1D, nn.Conv2d, nn.Linear]``.
        name: qualified-name prefix of ``module`` itself.

    Returns:
        ``{qualified_name: layer}`` for every match. A module that matches is
        returned as-is and its children are not searched.
    """
    layer_types = layers or [transformers.pytorch_utils.Conv1D, nn.Conv2d, nn.Linear]
    if isinstance(module, tuple(layer_types)):
        return {name: module}

    found = {}
    for child_name, child in module.named_children():
        child_prefix = f"{name}.{child_name}" if name != "" else child_name
        found.update(find_layers(child, layers=layer_types, name=child_prefix))
    return found
def get_module_by_name_prefix(model, module_name: str):
    """Return the first submodule (in ``named_modules`` order) whose qualified
    name starts with ``module_name``, or ``None`` if there is none."""
    return next(
        (
            candidate
            for qualified_name, candidate in model.named_modules()
            if qualified_name.startswith(module_name)
        ),
        None,
    )
def get_module_by_name_suffix(model, module_name: str):
    """Return the first submodule (in ``named_modules`` order) whose qualified
    name ends with ``module_name``, or ``None`` if there is none."""
    return next(
        (
            candidate
            for qualified_name, candidate in model.named_modules()
            if qualified_name.endswith(module_name)
        ),
        None,
    )
def make_quant(
    module,
    names,
    bits: int,
    group_size: int,
    backend: Backend,
    format: str,
    desc_act: bool = False,
    sym: bool = True,
    use_cuda_fp16: bool = True,
    pack: bool = False,
) -> BaseQuantLinear:
    """Swap the named linear-like submodules of ``module`` for quantized layers.

    Args:
        module: root module to modify in place.
        names: container of qualified submodule names to replace (tested with ``in``).
        bits, group_size, desc_act, sym, backend, format, pack: settings used to
            select the quantized-linear class and to construct each new layer.
        use_cuda_fp16: forwarded to the layer constructor, but only when
            ``(not desc_act or group_size == -1)`` and the backend is not TRITON.

    Returns:
        The selected quantized-linear class.

    Raises:
        NotImplementedError: if a named submodule is not ``nn.Linear``,
            ``nn.Conv2d`` or ``transformers.pytorch_utils.Conv1D``.
    """
    selector = select_quant_linear_with_pack if pack else select_quant_linear
    QuantLinear = selector(
        bits=bits,
        group_size=group_size,
        desc_act=desc_act,
        sym=sym,
        backend=backend,
        format=format,
        pack=pack,
    )

    # Nothing to replace when the root itself is already a quantized layer.
    if isinstance(module, QuantLinear):
        return QuantLinear

    # Only these configurations receive the use_cuda_fp16 constructor argument.
    forward_cuda_fp16 = (not desc_act or group_size == -1) and backend != Backend.TRITON

    for layer_name, layer in module.named_modules():
        if layer_name not in names:
            continue

        layer_device = next(layer.parameters()).device
        if isinstance(layer, nn.Linear):
            in_features, out_features = layer.in_features, layer.out_features
        elif isinstance(layer, nn.Conv2d):
            in_features, out_features = layer.in_channels, layer.out_channels
        elif isinstance(layer, transformers.pytorch_utils.Conv1D):
            # Conv1D stores its weight as (in_features, out_features).
            in_features, out_features = layer.weight.shape[0], layer.weight.shape[1]
        else:
            raise NotImplementedError(f"Unsupported module {layer}")

        layer_kwargs = dict(
            bits=bits,
            group_size=group_size,
            desc_act=desc_act,
            sym=sym,
            infeatures=in_features,
            outfeatures=out_features,
            bias=layer.bias is not None,
            weight_dtype=layer.weight.dtype,
        )
        if forward_cuda_fp16:
            layer_kwargs["use_cuda_fp16"] = use_cuda_fp16
        new_layer = QuantLinear(**layer_kwargs)

        # Keep the replacement on the device the original layer lived on.
        new_layer.device = layer_device
        recurse_setattr(module, layer_name, new_layer.to(layer_device))
    return QuantLinear
def convert_gptq_v1_to_v2_format(
    model,
    quantize_config: QuantizeConfig,
    qlinear_kernel: nn.Module,
):
    """Convert packed GPTQ v1 zero-points to the v2 layout, in place.

    v1 checkpoints subtracted one from the zero-points before serialization.
    This undoes it by adding a bit-width-specific constant to the packed int32
    words of ``qzeros`` on every submodule that is an instance of ``qlinear_kernel``.

    Args:
        model: model whose quantized-linear submodules are modified in place.
        quantize_config: provides ``bits`` (2, 3, 4 or 8).
        qlinear_kernel: quantized-linear class whose instances are converted.

    Returns:
        The same ``model`` object.

    Raises:
        NotImplementedError: if ``bits`` is unsupported (raised when the first
            matching submodule is reached).
    """
    # Limit thread usage to avoid an auto-parallelization regression.
    with tctl.threadpool_limits(limits=1):
        for _, submodule in model.named_modules():
            if not isinstance(submodule, qlinear_kernel):
                continue
            # v1 checkpoint format used to do `qzeros = qzeros -= 1` before
            # serialization, thus the additions here do not overflow.
            # v1 checkpoint format with sym=False saved via
            # convert_gptq_v2_to_v1_format() will overflow ~<=13% based on testing
            bits = quantize_config.bits
            if bits == 2:
                submodule.qzeros.data += 0b01010101010101010101010101010101
            elif bits == 3:
                # The 3-bit pattern repeats every three int32 columns, so each
                # column position (mod 3) gets its own constant.
                num_cols = submodule.qzeros.data.shape[1]
                column_patterns = (
                    0b00100100100100100100100100100100,
                    0b10010010010010010010010010010010,
                    0b01001001001001001001001001001001,
                )
                for start, pattern in enumerate(column_patterns):
                    submodule.qzeros.data[:, range(start, num_cols, 3)] += pattern
            elif bits == 4:
                submodule.qzeros.data += 0b00010001000100010001000100010001
            elif bits == 8:
                submodule.qzeros.data += 0b00000001000000010000000100000001
            else:
                raise NotImplementedError("Only 2,3,4,8 bits are supported.")
    return model
def convert_gptq_v2_to_v1_format(
    model,
    quantize_config: QuantizeConfig,
    qlinear_kernel: nn.Module,
):
    """Convert packed GPTQ v2 zero-points back to the v1 layout, in place.

    Subtracts the same bit-width-specific constants that
    ``convert_gptq_v1_to_v2_format`` adds, from the packed int32 words of
    ``qzeros`` on every submodule that is an instance of ``qlinear_kernel``.

    Args:
        model: model whose quantized-linear submodules are modified in place.
        quantize_config: provides ``bits`` (2, 3, 4 or 8).
        qlinear_kernel: quantized-linear class whose instances are converted.

    Returns:
        The same ``model`` object.

    Raises:
        NotImplementedError: if ``bits`` is unsupported (raised when the first
            matching submodule is reached).
    """
    # Limit thread usage to avoid an auto-parallelization regression.
    with tctl.threadpool_limits(limits=1):
        for _, submodule in model.named_modules():
            if not isinstance(submodule, qlinear_kernel):
                continue
            # sym=False has underflow probability of ~<=13% during testing.
            # No underflow possible for sym=True.
            bits = quantize_config.bits
            if bits == 2:
                submodule.qzeros.data -= 0b01010101010101010101010101010101
            elif bits == 3:
                # The 3-bit pattern repeats every three int32 columns, so each
                # column position (mod 3) gets its own constant.
                num_cols = submodule.qzeros.data.shape[1]
                column_patterns = (
                    0b00100100100100100100100100100100,
                    0b10010010010010010010010010010010,
                    0b01001001001001001001001001001001,
                )
                for start, pattern in enumerate(column_patterns):
                    submodule.qzeros.data[:, range(start, num_cols, 3)] -= pattern
            elif bits == 4:
                submodule.qzeros.data -= 0b00010001000100010001000100010001
            elif bits == 8:
                submodule.qzeros.data -= 0b00000001000000010000000100000001
            else:
                raise NotImplementedError("Only 2,3,4,8 bits are supported.")
    return model
def select_quant_linear_with_pack(
    bits: int,
    group_size: int,
    desc_act: bool,
    sym: bool,
    backend: Backend,
    format: str,
    pack: bool,
):
    """Return the quantized-linear class for these settings.

    A thin wrapper that forwards every argument, by keyword, to
    ``select_quant_linear``.
    """
    return select_quant_linear(
        bits=bits,
        group_size=group_size,
        desc_act=desc_act,
        sym=sym,
        backend=backend,
        format=format,
        pack=pack,
    )
def pack_model(
    model,
    quantizers,
    bits,
    group_size,
    backend: Backend,
    format: str,
    desc_act=False,
    sym: bool = True,
    use_cuda_fp16=True,
    warmup_triton: bool = False,
    force_layer_back_to_cpu: bool = False,
):
    """Replace the quantized layers of ``model`` with packed quantized-linear modules.

    Args:
        model: model whose layers listed in ``quantizers`` were quantized.
        quantizers: mapping from qualified layer name to a
            ``(quantizer, scale, zero, g_idx)`` tuple. NOTE: each entry is
            overwritten in place with just its quantizer while packing.
        bits, group_size, desc_act, sym: quantization settings used both to
            select ``QuantLinear`` and to construct the replacement layers.
        backend, format: forwarded to kernel selection.
        use_cuda_fp16: forwarded to ``make_quant``.
        warmup_triton: request a ``QuantLinear.warmup`` pass after packing
            (see the condition at the end).
        force_layer_back_to_cpu: move the whole model to CPU before packing.

    Returns:
        The quantized-linear class used for packing.
    """
    QuantLinear = select_quant_linear_with_pack(
        bits=bits,
        group_size=group_size,
        desc_act=desc_act,
        sym=sym,
        backend=backend,
        format=format,
        pack=True,
    )

    if force_layer_back_to_cpu:
        model.to(CPU)

    logger.info("Packing model...")
    # Keep references to the original float layers before they are replaced.
    layers = find_layers(model)
    layers = {n: layers[n] for n in quantizers}
    # Forward `sym` so the replacement layers are selected and constructed with
    # the same settings as `QuantLinear` above. Otherwise make_quant falls back
    # to sym=True, which may pick a different class that
    # find_layers(model, [QuantLinear]) would then not find.
    make_quant(
        model,
        quantizers,
        bits,
        group_size,
        backend=backend,
        format=format,
        use_cuda_fp16=use_cuda_fp16,
        desc_act=desc_act,
        sym=sym,
        pack=True,
    )
    qlayers = find_layers(model, [QuantLinear])
    # Limit pack() thread usage to avoid an auto-parallelization regression.
    with tctl.threadpool_limits(limits=1):
        pbar = tqdm(qlayers.keys(), leave=True)
        for name in pbar:
            pbar.set_description(f"Packing {name}")
            quantizers[name], scale, zero, g_idx = quantizers[name]
            # so far can only pack layer on CPU
            layer_device = qlayers[name].device
            qlayers[name].to(CPU)
            layers[name], scale, zero, g_idx = (
                layers[name].to(CPU),
                scale.to(CPU),
                zero.to(CPU),
                g_idx.to(CPU),
            )
            # The marlin kernel's pack() takes only the layer and its scales.
            if QuantLinear.QUANT_TYPE == "marlin":
                qlayers[name].pack(layers[name], scale)
            else:
                qlayers[name].pack(layers[name], scale, zero, g_idx)
            qlayers[name].to(layer_device)
    logger.info("Model packed.")

    # NOTE(review): warmup runs only when the backend is *not* TRITON, although
    # the flag is named `warmup_triton` — confirm the condition is intended.
    if backend != Backend.TRITON and warmup_triton:
        logger.warning(
            "using autotune_warmup will move model to GPU, make sure you have enough VRAM to load the whole model."
        )
        QuantLinear.warmup(model.to(CUDA_0), seqlen=model.seqlen)
    return QuantLinear
def verify_model_hash(file_path: str, verify_hash: str):
    """Check that the digest of ``file_path`` matches ``verify_hash``.

    Args:
        file_path: path of the file to hash.
        verify_hash: expected digest written as ``"<algorithm>:<hex digest>"``,
            e.g. ``"sha256:ab12..."``. The algorithm must be a name accepted by
            ``hashlib.new``.

    Returns:
        True if the file's hex digest equals the expected one (compared
        case-insensitively, ignoring surrounding whitespace), else False.

    Raises:
        ValueError: if ``verify_hash`` is not a string, has no ``":"``, or
            names an unknown algorithm.
    """
    if not isinstance(verify_hash, str):
        raise ValueError("model verify_hash must be a string")
    if ":" not in verify_hash:
        raise ValueError("verify_hash must be in the format 'hash_type:hash_value'")
    hash_type, hash_value = verify_hash.split(":", 1)

    # hashlib.new only accepts real hash algorithms, unlike getattr(hashlib, ...),
    # which can return unrelated module attributes.
    try:
        hasher = hashlib.new(hash_type)
    except ValueError:
        raise ValueError(f"No hash function found for type: {hash_type}") from None

    # Stream the file in 1 MiB chunks: checkpoint shards can be many GB, and
    # reading one whole would need that much RAM.
    with open(file_path, "rb") as f:
        for chunk in iter(lambda: f.read(1 << 20), b""):
            hasher.update(chunk)
    # hexdigest() is lowercase; accept user-supplied digests in any case.
    return hasher.hexdigest() == hash_value.strip().lower()
def verify_sharded_model_hashes(jsonPath: str, verify_hash: List[str]):
    """Verify every shard listed in a sharded-checkpoint index file.

    Args:
        jsonPath: path to the ``*.index.json`` file; its ``"weight_map"`` values
            are the shard file names, resolved relative to the index file's
            directory.
        verify_hash: one ``"<algorithm>:<hex digest>"`` string per distinct
            shard, given in the **sorted** order of the shard file names.

    Returns:
        True if every shard matches, False at the first mismatch (which is logged).

    Raises:
        ValueError: if ``verify_hash`` is not a list or its length differs from
            the number of distinct shards.
    """
    if not isinstance(verify_hash, list):
        raise ValueError("sharded model verify_hash must be a list")
    with open(jsonPath, "r") as f:
        index_data = json.load(f)
    weight_map = index_data["weight_map"]
    # Sort for a deterministic shard -> hash pairing. Iterating a raw set of
    # strings depends on per-process hash randomization.
    shard_files = sorted(set(weight_map.values()))
    if len(shard_files) != len(verify_hash):
        raise ValueError("Number of shards and number of hash values do not match.")
    # Shard names in the index are relative to the index file, not to the CWD.
    index_dir = os.path.dirname(jsonPath)
    for shard_file, expected_hash in zip(shard_files, verify_hash):
        if not verify_model_hash(os.path.join(index_dir, shard_file), expected_hash):
            logger.info(f"Hash verification failed for {shard_file}")
            return False
    return True
def check_and_get_model_type(model_dir, trust_remote_code=False):
    """Load the config at ``model_dir`` and return its ``model_type``.

    Raises:
        TypeError: if the model type is not in ``SUPPORTED_MODELS``.
    """
    model_type = AutoConfig.from_pretrained(
        model_dir, trust_remote_code=trust_remote_code
    ).model_type
    if model_type in SUPPORTED_MODELS:
        return model_type
    raise TypeError(f"{model_type} isn't supported yet.")
def simple_dispatch_model(model, device_map):
    """Place ``model`` according to ``device_map`` using accelerate hooks.

    Args:
        model: the model to dispatch.
        device_map: mapping from a module-name suffix (looked up with
            ``get_module_by_name_suffix``) to a device string. The key ``""``
            means "the whole model". ``"cpu"`` and ``"disk"`` are treated specially.

    Returns:
        The model, with ``hf_device_map`` set to ``device_map``.

    Behavior:
        - If ``""`` is a key, the whole model is moved to that device and
          returned without hooks.
        - Otherwise modules mapped to ``"cpu"`` are chained with
          ``accelerate.cpu_offload_with_hook`` so they execute on the first
          non-cpu/disk device (or ``"cpu"`` if there is none). With more than
          one such module, the chain is closed into a loop.
        - Every entry not mapped to ``"cpu"`` gets an ``AlignDevicesHook`` for
          its device. Tied parameters are re-tied at the end.
    """
    # Third Party
    from accelerate.hooks import AlignDevicesHook, add_hook_to_module

    # Whole-model placement: no hooks needed.
    if "" in device_map:
        d = device_map[""]
        model = model.to(torch.device(d))
        model.hf_device_map = device_map
        return model

    # Record tied parameters so they can be re-tied after hooks move weights.
    tied_params = accelerate.utils.modeling.find_tied_parameters(model)
    if set(device_map.values()) == {"cpu"} or set(device_map.values()) == {
        "cpu",
        "disk",
    }:
        main_device = "cpu"
    else:
        # First accelerator device in device_map order.
        main_device = [d for d in device_map.values() if d not in ["cpu", "disk"]][0]

    cpu_offload_group = [(n, d) for n, d in device_map.items() if d == "cpu"]
    prev_hook = None
    for idx, (n, d) in enumerate(cpu_offload_group):
        # NOTE(review): suffix lookup returns the first module whose name ends
        # with `n`, which may not be unique — confirm device_map keys are unambiguous.
        m = get_module_by_name_suffix(model, n)
        _, prev_hook = accelerate.cpu_offload_with_hook(
            m, execution_device=main_device, prev_module_hook=prev_hook
        )
    # set first cpu offload module's prev_module_hook to the last cpu offload module's hook
    if len(cpu_offload_group) > 1:
        get_module_by_name_suffix(
            model, cpu_offload_group[0][0]
        )._hf_hook.prev_module_hook = prev_hook

    for n, d in device_map.items():
        m = get_module_by_name_suffix(model, n)
        if d != "cpu":
            # NOTE(review): a "disk" entry would reach torch.device("disk")
            # here — confirm disk offload is never passed to this function.
            d = torch.device(d)
            hook = AlignDevicesHook(d, io_same_device=True, place_submodules=True)
            add_hook_to_module(m, hook)
    accelerate.utils.modeling.retie_parameters(model, tied_params)
    model.hf_device_map = device_map
    return model
# TODO: refactor. It is very strange that post_init has to re-determine the qlinear type:
# once the qlinear type is selected, it should automatically override the model's
# post_init method, instead of looping over the modules a second time to match the
# qlinear type, which is very prone to bugs.
def gptqmodel_post_init(
    model, use_act_order: bool, max_input_length: Optional[int] = None
):
    """
    Run backend-specific post-initialization on the quantized-linear submodules of `model`.

    - bitblas: call `post_init()` on every submodule whose `QUANT_TYPE` is "bitblas".
    - exllama: compute per-device buffer sizes, allocate float16 `temp_state` and
      `temp_dq` buffers on each device, store them on `model.device_to_buffers`,
      register them with `prepare_buffers`, set the kernel tuning parameters, and
      then call `post_init()` on each exllama submodule.
    - exllamav2: allocate one `ExLlamaV2DeviceTensors` per device, sized to the
      largest `scratch_space_fixed()` on that device, store them on
      `model.device_tensors`, and pass them to each exllamav2 submodule's `post_init`.

    Finally calls `torch.cuda.empty_cache()`.

    Args:
        model: model containing quantized-linear submodules.
        use_act_order: whether the checkpoint uses act-order. For exllama this sets
            each submodule's `_use_act_order` and sizes `temp_state` by the largest
            in/out feature count.
        max_input_length: The max_input_length argument is specific to the exllama
            backend, that requires to initialize a buffer temp_state. Only used when
            `use_act_order` is set (defaults to EXLLAMA_DEFAULT_MAX_INPUT_LENGTH);
            otherwise it is ignored with an info log.

    Returns:
        The same `model` object.
    """
    # post init for bitblas backend.
    device_to_buffers_size = {}
    for _, submodule in model.named_modules():
        if hasattr(submodule, "QUANT_TYPE") and submodule.QUANT_TYPE == "bitblas":
            submodule.post_init()
    model_uses_exllama = False
    for name, submodule in model.named_modules():
        if isinstance(submodule, BaseQuantLinear) and submodule.QUANT_TYPE == "exllama":
            model_uses_exllama = True
            device = submodule.qweight.device
            # One size record per device; both sizes start at 1 and grow below.
            if device not in device_to_buffers_size:
                device_to_buffers_size[device] = {
                    "max_dq_buffer_size": 1,
                    "max_inner_outer_dim": 1,
                }
            if not use_act_order:
                submodule._use_act_order = False
            else:
                submodule._use_act_order = True
            # Disable this heuristic for detecting act_order, but it could be used instead of the config.
            # (The string literal below is a no-op expression kept as disabled code.)
            """
if submodule.g_idx is None:
submodule.act_order = False
elif submodule.g_idx is not None and ((submodule.g_idx == 0).all() or torch.equal(submodule.g_idx.cpu(), torch.tensor([i // submodule.group_size for i in range(submodule.g_idx.shape[0])], dtype=torch.int32))):
submodule.g_idx = None
submodule.act_order = False
else:
submodule.act_order = True
"""
            # Track the largest qweight.numel() * 8 on this device
            # (presumably 8 packed values per int32 element — TODO confirm).
            device_to_buffers_size[device]["max_dq_buffer_size"] = max(
                device_to_buffers_size[device]["max_dq_buffer_size"],
                submodule.qweight.numel() * 8,
            )
            if use_act_order:
                device_to_buffers_size[device]["max_inner_outer_dim"] = max(
                    device_to_buffers_size[device]["max_inner_outer_dim"],
                    submodule.infeatures,
                    submodule.outfeatures,
                )
    if model_uses_exllama:
        # To be honest this is quite ugly, not proud of this.
        # Third Party
        from gptqmodel_exllama_kernels import prepare_buffers, set_tuning_params

        device_to_buffers = {}
        if use_act_order:
            if max_input_length is None:
                max_input_len = EXLLAMA_DEFAULT_MAX_INPUT_LENGTH
            else:
                max_input_len = max_input_length
        else:
            if max_input_length is not None:
                logger.info(
                    "Using exllama backend without act-order, the parameter max_input_length was set although not needed, it will be ignored."
                )
            # Without act-order temp_state is not used for reordering, so one row suffices.
            max_input_len = 1
        for device, buffers_size in device_to_buffers_size.items():
            # The temp_state buffer is required to reorder X in the act-order case.
            # The temp_dq buffer is required to dequantize weights when using cuBLAS, typically for the prefill.
            device_to_buffers[device] = {
                "temp_state": torch.zeros(
                    (max_input_len, buffers_size["max_inner_outer_dim"]),
                    dtype=torch.float16,
                    device=device,
                ),
                "temp_dq": torch.zeros(
                    (1, buffers_size["max_dq_buffer_size"]),
                    dtype=torch.float16,
                    device=device,
                ),
                "max_dq_buffer_size": buffers_size["max_dq_buffer_size"],
                "max_inner_outer_dim": buffers_size["max_inner_outer_dim"],
            }
        # Buffers need to be persistent to avoid any bug.
        model.device_to_buffers = device_to_buffers
        for device, buffers in model.device_to_buffers.items():
            prepare_buffers(device, buffers["temp_state"], buffers["temp_dq"])
        # Using the default from exllama repo here.
        matmul_recons_thd = 8
        matmul_fused_remap = False
        matmul_no_half2 = False
        set_tuning_params(matmul_recons_thd, matmul_fused_remap, matmul_no_half2)
        # The buffers need to have been initialized first before calling make_q4.
        for name, submodule in model.named_modules():
            if (
                isinstance(submodule, BaseQuantLinear)
                and submodule.QUANT_TYPE == "exllama"
            ):
                submodule.post_init()
    # exllamav2
    # Largest fixed scratch size required by any exllamav2 submodule, per device.
    fixed_bytes = {}
    model_uses_exllamav2 = False
    for _, submodule in model.named_modules():
        if (
            isinstance(submodule, BaseQuantLinear)
            and submodule.QUANT_TYPE == "exllamav2"
        ):
            model_uses_exllamav2 = True
            device = submodule.qweight.device
            scratch_fixed = submodule.scratch_space_fixed()
            fixed_bytes[device] = max(scratch_fixed, fixed_bytes.get(device, 0))
    if model_uses_exllamav2:
        # Local
        from ..nn_modules.qlinear.qlinear_exllamav2 import ExLlamaV2DeviceTensors

        device_tensors = {}
        for device, scratch_bytes in fixed_bytes.items():
            device_tensors[device] = ExLlamaV2DeviceTensors(device.index, scratch_bytes)
        # have persistent buffers, otherwise we will get OOM
        model.device_tensors = device_tensors
        for _, submodule in model.named_modules():
            if (
                isinstance(submodule, BaseQuantLinear)
                and submodule.QUANT_TYPE == "exllamav2"
            ):
                device = submodule.qweight.device
                submodule.post_init(temp_dq=model.device_tensors[device])
    torch.cuda.empty_cache()
    return model
def get_checkpoints(
    model_name_or_path: str,
    extensions: List[str],
    possible_model_basenames: List[str],
    **cached_file_kwargs,
):
    """
    Retrieve (and, if necessary, download from the Hugging Face Hub) the model checkpoint. Sharding is supported. All the `possible_model_basenames` (e.g. `["model", "model-4bit-gptq"]`) are tried with every one of the `extensions` (e.g. `[".bin", ".safetensors"]`).

    Candidates are tried in order: for each extension, for each basename, first
    `<basename><ext>.index.json` (sharded) and then `<basename><ext>`.

    Args:
        model_name_or_path: local directory, or a Hub repo id (any non-directory).
        extensions: file extensions to try, in order.
        possible_model_basenames: basenames to try, in order.
        **cached_file_kwargs: forwarded to `transformers.utils.hub.cached_file` (Hub case only).

    Returns:
        `(is_sharded, path, basename)`:
        - sharded: `(True, <index file path>, basename)`. For a local directory
          this basename is the index path with `<ext>.index.json` stripped, so it
          includes the directory.
        - single file: `(False, <checkpoint file path>, basename)`.

    Raises:
        FileNotFoundError: if no candidate was found.
    """
    searched_files = []
    resolved_archive_file = None
    true_model_basename = None
    if os.path.isdir(model_name_or_path):
        for ext in extensions:
            for possible_model_basename in possible_model_basenames:
                shard_index_name = possible_model_basename + ext + ".index.json"
                searched_files.append(shard_index_name)
                possible_index_file = os.path.join(model_name_or_path, shard_index_name)
                if os.path.isfile(possible_index_file):
                    # The model is sharded over several checkpoints.
                    possible_model_basename = possible_index_file.replace(
                        ext + ".index.json", ""
                    )
                    return True, possible_index_file, possible_model_basename
                else:
                    model_save_name = os.path.join(
                        model_name_or_path, possible_model_basename
                    )
                    searched_files.append(possible_model_basename + ext)
                    if os.path.isfile(model_save_name + ext):
                        resolved_archive_file = model_save_name + ext
                        return False, resolved_archive_file, possible_model_basename
    else:
        temp = None
        for ext in extensions:
            for possible_model_basename in possible_model_basenames:
                shard_index_name = possible_model_basename + ext + ".index.json"
                # NOTE(review): the `is not None` checks below assume cached_file
                # returns None for a missing file, presumably only when
                # cached_file_kwargs disable raising on missing entries — confirm.
                shard_index = cached_file(
                    model_name_or_path,
                    shard_index_name,
                    **cached_file_kwargs,
                )
                searched_files.append(shard_index_name)
                if shard_index is not None:
                    # The model is sharded over several checkpoints.
                    with open(str(shard_index)) as f:
                        index_json = json.load(f)
                    # Download the shards from the index.json.
                    shards = list(set(index_json["weight_map"].values()))
                    for shard in shards:
                        resolved_archive_file = cached_file(
                            model_name_or_path,
                            shard,
                            **cached_file_kwargs,
                        )
                    return True, shard_index, possible_model_basename
                else:
                    resolved_archive_file = cached_file(
                        model_name_or_path,
                        possible_model_basename + ext,
                        **cached_file_kwargs,
                    )
                    # NOTE(review): `temp` is only assigned immediately before a
                    # return, so this fallback always yields None.
                    if resolved_archive_file is None:
                        resolved_archive_file = temp
                    searched_files.append(possible_model_basename + ext)
                    if resolved_archive_file is not None:
                        temp = resolved_archive_file
                        return False, resolved_archive_file, possible_model_basename
    if resolved_archive_file is None:
        raise FileNotFoundError(
            f"Could not find a model in {model_name_or_path} with a name in {', '.join(searched_files)}. Please specify the argument model_basename to use a custom file name."
        )
    # NOTE(review): resolved_archive_file is only non-None right before one of
    # the returns above, so this line appears unreachable.
    return False, resolved_archive_file, true_model_basename
# return the most stable tensor dtype for quantization while minimizing vram
def auto_dtype_from_config(
    config: "PretrainedConfig", quant_inference: bool = False
) -> torch.dtype:
    """Pick the working dtype for a model from its config.

    Args:
        config: model config; only its ``torch_dtype`` attribute is read. It may
            be a ``torch.dtype`` or a dtype name such as ``"bfloat16"``.
        quant_inference: when True, always return ``torch.float16`` without
            reading the config.

    Returns:
        ``torch.float16`` if the config dtype is float16 (or for quantized
        inference). Otherwise ``torch.bfloat16``: float32 and any other dtype
        are cast to bfloat16.

    Raises:
        ValueError: if ``torch_dtype`` is missing, None, or not a valid dtype.
    """
    # all the gptq inference kernels are float16 only
    if quant_inference:
        return torch.float16

    # A missing attribute must produce the ValueError below, not AttributeError.
    dtype = getattr(config, "torch_dtype", None)
    # Configs may carry the dtype as its name (e.g. "bfloat16"); resolve it.
    if isinstance(dtype, str):
        dtype = getattr(torch, dtype, None)
    if not dtype or not isinstance(dtype, torch.dtype):
        raise ValueError(
            "Your model config.json does not have torch_dtype set. Please check for model "
            "corruption."
        )

    # Keep float16; up/down-cast everything else (float32 included) to bfloat16.
    return torch.float16 if dtype == torch.float16 else torch.bfloat16
# generate layer modules for moe models with experts
def get_moe_layer_modules(layer_modules: List, num_experts: int) -> List:
    """Expand expert-placeholder module names into one name per expert.

    Args:
        layer_modules: list of groups (lists) of module names. A name containing
            ``EXPERT_INDEX_PLACEHOLDER`` is a per-expert template.
        num_experts: number of experts to expand each template into.

    Returns:
        A new list of groups in the same order. Every template is replaced, in
        place, by ``num_experts`` names with the placeholder set to
        ``0 .. num_experts - 1``; other names are kept unchanged.
    """

    def _expand(module_name):
        if EXPERT_INDEX_PLACEHOLDER not in module_name:
            return [module_name]
        return [
            module_name.replace(EXPERT_INDEX_PLACEHOLDER, str(expert_idx))
            for expert_idx in range(num_experts)
        ]

    return [
        [expanded for module_name in group for expanded in _expand(module_name)]
        for group in layer_modules
    ]
def replace_3d_parameters_with_module_list(
    model: torch.nn.Module,
):
    """Replace every 3-D parameter in ``model`` with a ``ModuleList`` of ``nn.Linear``.

    A parameter of shape ``(num, d1, d2)`` becomes ``num`` bias-free linear
    layers. Layer ``i`` uses the slice ``param.data[i]`` (shape ``(d1, d2)``) as
    its weight. ``nn.Linear`` stores its weight as ``(out_features, in_features)``,
    so each layer maps ``d2`` inputs to ``d1`` outputs, and its
    ``in_features``/``out_features`` are set to match.
    NOTE(review): this assumes each slice is in nn.Linear's (out, in) layout —
    confirm for the models this is used with.

    The model is modified in place; returns None.
    """
    # Snapshot both iterators: the loop deletes and adds attributes on the
    # modules it walks, and mutating a dict while lazily iterating it raises
    # "RuntimeError: dictionary changed size during iteration".
    for _, module in list(model.named_modules()):
        for param_name, param in list(module.named_parameters(recurse=False)):
            if len(param.shape) != 3:
                continue
            num, out_features, in_features = param.shape
            module_list = []
            for i in range(num):
                linear = torch.nn.Linear(
                    in_features=in_features,
                    out_features=out_features,
                    device=param.device,
                    dtype=param.dtype,
                    bias=False,  # FIXME: how to support bias?
                )
                # The freshly allocated weight already has shape (d1, d2), so this
                # slice (which shares storage with the 3-D parameter) fits it exactly.
                linear.weight.data = param.data[i]
                module_list.append(linear)
            # replace
            delattr(module, param_name)
            setattr(module, param_name, torch.nn.ModuleList(module_list))