@@ -1899,6 +1899,15 @@ def process_chunk(chunk_start, chunk_end, skip_result_store=False):
             staged_hidden_states.copy_(hidden_states, non_blocking=True)
             staged_router_logits.copy_(router_logits, non_blocking=True)
 
+            # If there are shared experts but we are not using a modular kernel,
+            # the shared experts must be called here
+            if (not isinstance(self.quant_method.fused_experts,
+                               FusedMoEModularKernel)
+                    and self.shared_experts is not None):
+                shared_output = self.shared_experts(staged_hidden_states)
+            else:
+                shared_output = None
+
             # Matrix multiply.
             final_hidden_states = self.quant_method.apply(
                 layer=self,
@@ -1922,8 +1931,13 @@ def process_chunk(chunk_start, chunk_end, skip_result_store=False):
                 logical_replica_count=self.logical_replica_count,
             )
 
-            assert self.shared_experts is None or isinstance(
-                final_hidden_states, tuple)
+            if shared_output is not None:
+                assert not isinstance(final_hidden_states, tuple)
+                assert self.shared_experts is not None
+                final_hidden_states = (
+                    shared_output,
+                    final_hidden_states,
+                )
 
             if self.zero_expert_num is not None and self.zero_expert_num > 0:
                 assert isinstance(final_hidden_states, tuple)
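
For context, here is a minimal, self-contained sketch of the control flow this diff introduces: when the fused-experts implementation is not a modular kernel that runs shared experts internally, the layer computes the shared-expert output on the staged hidden states itself and afterwards packs it with the routed output into a (shared, routed) tuple. All names below (MoELayerSketch, ModularKernel, the nn.Linear stand-ins) are hypothetical placeholders for illustration only, not vLLM's actual classes or the real quant_method.apply signature.

    # Hypothetical sketch of the shared-experts fallback pattern; not vLLM code.
    from typing import Optional

    import torch
    import torch.nn as nn


    class ModularKernel(nn.Module):
        """Stand-in for a kernel that runs shared experts internally."""

        def __init__(self, routed: nn.Module, shared: nn.Module):
            super().__init__()
            self.routed = routed
            self.shared = shared

        def forward(self, x: torch.Tensor):
            # A modular kernel returns (shared_output, routed_output) itself.
            return self.shared(x), self.routed(x)


    class MoELayerSketch(nn.Module):
        def __init__(self, fused_experts: nn.Module,
                     shared_experts: Optional[nn.Module]):
            super().__init__()
            self.fused_experts = fused_experts
            self.shared_experts = shared_experts

        def forward(self, staged_hidden_states: torch.Tensor):
            # If there are shared experts but the kernel is not modular,
            # the shared experts must be called here (as in the diff above).
            if (not isinstance(self.fused_experts, ModularKernel)
                    and self.shared_experts is not None):
                shared_output = self.shared_experts(staged_hidden_states)
            else:
                shared_output = None

            # Routed-expert computation ("Matrix multiply" in the diff).
            final_hidden_states = self.fused_experts(staged_hidden_states)

            if shared_output is not None:
                # The non-modular path returns a plain tensor; pack it with the
                # shared-expert output so callers always receive a
                # (shared, routed) tuple whenever shared experts exist.
                assert not isinstance(final_hidden_states, tuple)
                final_hidden_states = (shared_output, final_hidden_states)
            return final_hidden_states


    if __name__ == "__main__":
        hidden = 8
        x = torch.randn(4, hidden)
        routed = nn.Linear(hidden, hidden)
        shared = nn.Linear(hidden, hidden)

        # Non-modular kernel: the layer runs the shared experts itself.
        out = MoELayerSketch(routed, shared)(x)
        assert isinstance(out, tuple) and len(out) == 2

        # Modular kernel: shared experts are handled inside the kernel.
        out = MoELayerSketch(ModularKernel(routed, shared), shared)(x)
        assert isinstance(out, tuple) and len(out) == 2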