@@ -75,11 +75,11 @@ def create_quantized_state_dict(self):
                 cur_state_dict[f"{fqn}.weight"] = int8_weight
                 cur_state_dict[f"{fqn}.scales"] = scales.to(mod.weight.dtype)
             elif isinstance(mod, ConditionalFeedForward):
-                num_experts, intermediate_size, dim = mod.w1.shape
                 for weight_idx in range(0, 3):
                     weight_name = f"w{weight_idx + 1}"
                     scales_name = f"scales{weight_idx + 1}"
                     weight = getattr(mod, weight_name)
+                    num_experts, intermediate_size, dim = weight.shape

                     bit8_weight_list = []
                     scales_list = []
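Note on this hunk: `mod.w1.shape` only describes w1 and w3; w2 stores its weight with the last two axes swapped, so reading the shape from each `weight` inside the loop keeps the quantization axis correct for all three projections. Below is a minimal sketch of per-output-channel int8 quantization over the last axis; the `quantize_expert_weight` helper and the Mixtral-like sizes are illustrative assumptions, not the repo's actual helper.

import torch

def quantize_expert_weight(weight: torch.Tensor):
    # weight: [num_experts, out_features, in_features]; symmetric int8
    # quantization with one scale per (expert, output channel), reducing
    # over the last (input) axis -- the axis unpacked per weight above.
    max_abs = weight.abs().amax(dim=-1)          # [E, out]
    scales = (max_abs / 127.0).clamp(min=1e-8)   # [E, out]
    q = torch.round(weight / scales.unsqueeze(-1)).clamp(-128, 127).to(torch.int8)
    return q, scales

w1 = torch.randn(8, 14336, 4096)    # [E, intermediate_size, dim]
w2 = torch.randn(8, 4096, 14336)    # [E, dim, intermediate_size] -- transposed
_, s1 = quantize_expert_weight(w1)  # scales1: [8, 14336]
_, s2 = quantize_expert_weight(w2)  # scales2: [8, 4096], hence the dim-sized buffer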
@@ -125,20 +125,20 @@ def __init__(self, num_experts, intermediate_size, dim, target_dtype):
         self.target_dtype = target_dtype

         self.register_buffer("w1", torch.empty(num_experts, intermediate_size, dim, dtype=target_dtype))
-        self.register_buffer("w2", torch.empty(num_experts, intermediate_size, dim, dtype=target_dtype))
+        self.register_buffer("w2", torch.empty(num_experts, dim, intermediate_size, dtype=target_dtype))
         self.register_buffer("w3", torch.empty(num_experts, intermediate_size, dim, dtype=target_dtype))

         self.register_buffer("scales1", torch.empty(num_experts, intermediate_size, dtype=torch.bfloat16))
-        self.register_buffer("scales2", torch.empty(num_experts, intermediate_size, dtype=torch.bfloat16))
+        self.register_buffer("scales2", torch.empty(num_experts, dim, dtype=torch.bfloat16))
         self.register_buffer("scales3", torch.empty(num_experts, intermediate_size, dtype=torch.bfloat16))

     def forward(self, x, expert_indices):
-        w1_weights = (self.w1.to(x.dtype)[expert_indices] * self.scales1[expert_indices].to(x.dtype).unsqueeze(-1)).transpose(-1, -2)  # [T, A, D, D]
-        w3_weights = (self.w3.to(x.dtype)[expert_indices] * self.scales3[expert_indices].to(x.dtype).unsqueeze(-1)).transpose(-1, -2)  # [T, A, D, D]
-        w2_weights = (self.w2.to(x.dtype)[expert_indices] * self.scales2[expert_indices].to(x.dtype).unsqueeze(-1))  # [T, A, D, D]
-        x1 = F.silu(torch.einsum('ti,taio -> tao', x, w1_weights))
-        x3 = torch.einsum('ti, taio -> tao', x, w3_weights)
-        expert_outs = torch.einsum('tao, taoi -> tai', (x1 * x3), w2_weights)
+        w1_weights = self.w1.to(x.dtype)[expert_indices]  # [T, A, D, D]
+        w3_weights = self.w3.to(x.dtype)[expert_indices]  # [T, A, D, D]
+        w2_weights = self.w2.to(x.dtype)[expert_indices]
+        x1 = F.silu(torch.einsum('ti,taoi -> tao', x, w1_weights) * self.scales1[expert_indices].to(x.dtype))
+        x3 = torch.einsum('ti, taoi -> tao', x, w3_weights) * self.scales3[expert_indices].to(x.dtype)
+        expert_outs = torch.einsum('tao, taio -> tai', (x1 * x3), w2_weights) * self.scales2[expert_indices].to(x.dtype)  # [T, A, D]
         return expert_outs

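Why the rewritten forward is equivalent: a per-output-channel scale commutes with an einsum contraction over the input axis, so the scales can be applied to the (much smaller) activation after the matmul instead of dequantizing the full weight tensor first. A quick numerical check of that identity, using toy shapes and float stand-ins for the int8 weights:

import torch

T, A, O, I = 2, 2, 5, 3      # tokens, active experts, out dim, in dim
x = torch.randn(T, I)
q = torch.randn(T, A, O, I)  # stand-in for int8 weights cast to x.dtype
s = torch.randn(T, A, O)     # per-output-channel scales

# Dequantize-then-matmul vs. matmul-then-scale: equal up to float error.
pre  = torch.einsum('ti,taoi->tao', x, q * s.unsqueeze(-1))
post = torch.einsum('ti,taoi->tao', x, q) * s
assert torch.allclose(pre, post, atol=1e-5)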