@@ -134,23 +134,49 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
             selected_experts, num_classes=self.num_experts
         ).permute(2, 1, 0)
 
-        # Loop over all available experts in the model and perform the computation on each expert
-        for expert_idx in range(self.num_experts):
-            expert_layer = self.experts[expert_idx]
-            idx, top_x = torch.where(expert_mask[expert_idx])
-            # Index the correct hidden states and compute the expert hidden state for
-            # the current expert. We need to make sure to multiply the output hidden
-            # states by `routing_weights` on the corresponding tokens (top-1 and top-2)
-            current_state = hidden_states[None, top_x].reshape(-1, hidden_dim)
-            current_hidden_states = (
-                expert_layer(current_state) * routing_weights[top_x, idx, None]
-            )
-
-            # However `index_add_` only supports torch tensors for indexing so we'll use
-            # the `top_x` tensor here.
-            final_hidden_states.index_add_(
-                0, top_x, current_hidden_states.to(hidden_states.dtype)
-            )
+        # Separate paths for training (with .nonzero()) and inference (without .nonzero())
+        if self.training:
+            # Training path: use .nonzero() for efficiency (skip non-selected experts)
+            expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero()
+            for expert_idx in expert_hit:
+                expert_layer = self.experts[expert_idx]
+                idx, top_x = torch.where(expert_mask[expert_idx].squeeze(0))
+                # Index the correct hidden states and compute the expert hidden state for
+                # the current expert. We need to make sure to multiply the output hidden
+                # states by `routing_weights` on the corresponding tokens (top-1 and top-2)
+                current_state = hidden_states[None, top_x].reshape(-1, hidden_dim)
+                current_hidden_states = (
+                    expert_layer(current_state) * routing_weights[top_x, idx, None]
+                )
+
+                # However `index_add_` only supports torch tensors for indexing so we'll use
+                # the `top_x` tensor here.
+                final_hidden_states.index_add_(
+                    0, top_x, current_hidden_states.to(hidden_states.dtype)
+                )
+        else:
+            # Inference path: loop over all experts for torch.export compatibility
+            for expert_idx in range(self.num_experts):
+                expert_layer = self.experts[expert_idx]
+                idx, top_x = torch.where(expert_mask[expert_idx])
+
+                # Skip if no tokens are assigned to this expert
+                if top_x.shape[0] == 0:
+                    continue
+
+                # Index the correct hidden states and compute the expert hidden state for
+                # the current expert. We need to make sure to multiply the output hidden
+                # states by `routing_weights` on the corresponding tokens (top-1 and top-2)
+                current_state = hidden_states[None, top_x].reshape(-1, hidden_dim)
+                current_hidden_states = (
+                    expert_layer(current_state) * routing_weights[top_x, idx, None]
+                )
+
+                # However `index_add_` only supports torch tensors for indexing so we'll use
+                # the `top_x` tensor here.
+                final_hidden_states.index_add_(
+                    0, top_x, current_hidden_states.to(hidden_states.dtype)
+                )
 
         final_hidden_states = final_hidden_states.reshape(
             batch_size, sequence_length, hidden_dim
         )
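For context, here is a minimal, self-contained sketch of the routing bookkeeping this hunk manipulates. The toy sizes (`num_tokens`, `top_k`, etc.) and the `nn.Linear` experts are illustrative assumptions, not taken from the model; only the `one_hot`/`permute`/`torch.where`/`index_add_` pattern mirrors the code above.

```python
import torch

torch.manual_seed(0)
num_tokens, hidden_dim, num_experts, top_k = 4, 8, 3, 2

hidden_states = torch.randn(num_tokens, hidden_dim)
router_logits = torch.randn(num_tokens, num_experts)

# Top-k routing: both tensors are (num_tokens, top_k)
routing_weights, selected_experts = torch.topk(
    router_logits.softmax(dim=-1), top_k, dim=-1
)

# one_hot -> (num_tokens, top_k, num_experts); permute -> (num_experts, top_k, num_tokens)
expert_mask = torch.nn.functional.one_hot(
    selected_experts, num_classes=num_experts
).permute(2, 1, 0)

experts = [torch.nn.Linear(hidden_dim, hidden_dim) for _ in range(num_experts)]
final_hidden_states = torch.zeros_like(hidden_states)

for expert_idx in range(num_experts):
    # idx: which top-k slot picked this expert; top_x: which token rows it picked
    idx, top_x = torch.where(expert_mask[expert_idx])
    if top_x.shape[0] == 0:
        continue
    current_state = hidden_states[top_x]
    # Weight each expert output by the routing weight of its (token, slot) pair
    current = experts[expert_idx](current_state) * routing_weights[top_x, idx, None]
    # index_add_ scatters the weighted rows back to their original token positions
    final_hidden_states.index_add_(0, top_x, current)

print(final_hidden_states.shape)  # torch.Size([4, 8])
```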
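One non-obvious detail in the training path: `expert_hit` is the output of `.nonzero()`, so iterating over it yields shape-`(1,)` index tensors, and indexing `expert_mask` with one of those leaves a leading singleton dimension; that is why the new code calls `.squeeze(0)` before `torch.where`. A small sketch with made-up mask values (shapes chosen for illustration only):

```python
import torch

# Hypothetical mask for num_experts=3, top_k=2, num_tokens=2
expert_mask = torch.tensor([
    [[1, 0], [0, 0]],  # expert 0: selected by token 0 in slot 0
    [[0, 0], [0, 0]],  # expert 1: never selected
    [[0, 1], [1, 0]],  # expert 2: selected by token 1 (slot 0) and token 0 (slot 1)
])

# Count (slot, token) hits per expert, keep only experts with at least one
expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero()
print(expert_hit)  # tensor([[0], [2]]) -- expert 1 is skipped entirely

for expert_idx in expert_hit:
    # expert_idx has shape (1,), so indexing adds a leading dim of size 1
    print(expert_mask[expert_idx].shape)             # torch.Size([1, 2, 2])
    print(expert_mask[expert_idx].squeeze(0).shape)  # torch.Size([2, 2])
```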