Skip to content

Commit 7fb23a5

Browse files
committed
[fix] fix param for qkv linear, gpt2fused linear; fix requirements;
1 parent 63b7db5 commit 7fb23a5

File tree

3 files changed

+1
-7
lines changed

3 files changed

+1
-7
lines changed

colossalai/shardformer/layer/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
"LinearWithGradAccum",
2222
"Linear1D_Col",
2323
"Linear1D_Row",
24-
"GPT2FusedLinearConv1D_Col",
24+
"GPT2FusedLinearConv",
2525
"GPT2FusedLinearConv1D_Row",
2626
"GPT2FusedLinearConv1D_Col",
2727
"DropoutForParallelInput",

colossalai/shardformer/layer/qkv_fused_linear.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -832,7 +832,6 @@ def __init__(
832832
bias_: Optional[Parameter] = None,
833833
weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
834834
bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1),
835-
fp8_communication: bool = False,
836835
use_zbv: bool = False,
837836
):
838837
super().__init__()
@@ -846,7 +845,6 @@ def __init__(
846845
self.device = device
847846
self.split_sizes = split_sizes
848847
self.process_group = process_group
849-
self.fp8_communication = fp8_communication
850848
self.use_zbv = use_zbv
851849

852850
assert (
@@ -1246,7 +1244,6 @@ def __init__(
12461244
bias_: Optional[Parameter] = None,
12471245
weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
12481246
bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1),
1249-
fp8_communication: bool = False,
12501247
use_zbv: bool = False,
12511248
):
12521249
super().__init__()
@@ -1257,7 +1254,6 @@ def __init__(
12571254
self.seq_parallel_dim = seq_parallel_dim
12581255
self.skip_bias_add = skip_bias_add
12591256
self.device = device
1260-
self.fp8_communication = fp8_communication
12611257
self.use_zbv = use_zbv
12621258

12631259
if skip_bias_add and not bias:

requirements/requirements.txt

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,5 +24,3 @@ fastapi
2424
uvicorn==0.29.0
2525
galore_torch
2626
diffusers==0.29.0
27-
pyramid<=1.10.7
28-
zope<=5.5.2

0 commit comments

Comments
 (0)