@@ -266,10 +266,6 @@ def __init__(
266
266
)
267
267
268
268
self .hidden_size = fd_config .model_config .hidden_size
269
- self .weight_shape = [
270
- self .input_size ,
271
- self .output_size ,
272
- ]
273
269
274
270
assert self .quant_method is not None
275
271
self .quant_method .create_weights (
@@ -311,24 +307,21 @@ def __init__(
311
307
add_bias (bool): Whether to add bias in the current layer or in the pre/post layer. Defaults to False.
312
308
skip_quant (bool): Whether to skip quantization. Defaults to False.
313
309
"""
310
+ self .fd_config = fd_config
311
+ self .nranks = fd_config .parallel_config .tensor_parallel_size
312
+ self .input_size = input_size
313
+ self .output_size = divide (output_size , self .nranks ) # Split the output_size using TP inference.
314
+ self .hidden_size = fd_config .model_config .hidden_size
315
+
314
316
super ().__init__ (
315
317
fd_config = fd_config ,
316
318
prefix = prefix ,
317
- input_size = input_size ,
318
- output_size = output_size ,
319
+ input_size = self . input_size ,
320
+ output_size = self . output_size ,
319
321
with_bias = with_bias ,
320
322
add_bias = add_bias ,
321
323
skip_quant = skip_quant ,
322
324
)
323
- self .fd_config = fd_config
324
- self .nranks = fd_config .parallel_config .tensor_parallel_size
325
- self .input_size = input_size
326
- self .output_size = divide (output_size , self .nranks ) # Split the output_size using TP inference.
327
- self .hidden_size = fd_config .model_config .hidden_size
328
- self .weight_shape = [
329
- self .input_size ,
330
- self .output_size ,
331
- ]
332
325
333
326
assert self .quant_method is not None
334
327
self .quant_method .create_weights (
@@ -634,15 +627,6 @@ def __init__(
634
627
add_bias (bool): Whether to add bias in the current layer or in the pre/post layer. Defaults to False.
635
628
skip_quant (bool): Whether to skip quantization. Defaults to False.
636
629
"""
637
- super ().__init__ (
638
- fd_config = fd_config ,
639
- prefix = prefix ,
640
- input_size = input_size ,
641
- output_size = output_size ,
642
- with_bias = with_bias ,
643
- add_bias = add_bias ,
644
- skip_quant = skip_quant ,
645
- )
646
630
self .fd_config = fd_config
647
631
self .skip_quant = False
648
632
self .nranks = fd_config .parallel_config .tensor_parallel_size
@@ -654,11 +638,15 @@ def __init__(
654
638
self .input_size = divide (input_size , self .nranks )
655
639
self .output_size = output_size
656
640
657
- self .weight_shape = [
658
- self .input_size ,
659
- self .output_size ,
660
- ]
661
- self ._dtype = self ._helper .get_default_dtype ()
641
+ super ().__init__ (
642
+ fd_config = fd_config ,
643
+ prefix = prefix ,
644
+ input_size = self .input_size ,
645
+ output_size = self .output_size ,
646
+ with_bias = with_bias ,
647
+ add_bias = add_bias ,
648
+ skip_quant = skip_quant ,
649
+ )
662
650
663
651
assert self .quant_method is not None
664
652
self .quant_method .create_weights (
0 commit comments