diff --git a/tensorrt_llm/_torch/auto_deploy/transform/library/sharding.py b/tensorrt_llm/_torch/auto_deploy/transform/library/sharding.py
index deb82a43e66..30df2a6842c 100644
--- a/tensorrt_llm/_torch/auto_deploy/transform/library/sharding.py
+++ b/tensorrt_llm/_torch/auto_deploy/transform/library/sharding.py
@@ -646,7 +646,6 @@ def apply(self, gm: GraphModule, node: Node) -> None:
         _insert_sharded_moe(
             gm,
             node,
             self.config,
-            self.mlp_type,
             scale_names=self.scale_names(),
         )
@@ -664,7 +663,7 @@ def scale_names(self) -> List[str]:
         return ["input_scale", "weight_scale", "alpha"]
 
     def apply(self, gm: GraphModule, node: Node) -> None:
-        _insert_sharded_moe(gm, node, self.config, self.mlp_type, scale_names=self.scale_names())
+        _insert_sharded_moe(gm, node, self.config, scale_names=self.scale_names())
 
 
 EP_SHARDING_RULES = [
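
Both call sites drop the self.mlp_type positional argument, so _insert_sharded_moe now receives only the graph module, the MoE node, the sharding config, and the quantization scale names. Below is a minimal sketch of the call shape these hunks imply; it is not the actual implementation, and the name and type of the config parameter are assumptions. Only the argument layout is taken from the call sites above.

from typing import List, Optional

from torch.fx import GraphModule, Node


def _insert_sharded_moe(
    gm: GraphModule,
    node: Node,
    config,  # sharding/transform config object; exact type is an assumption
    scale_names: Optional[List[str]] = None,  # e.g. ["input_scale", "weight_scale", "alpha"] per the diff
) -> None:
    """Illustrative stub only; the real helper lives in sharding.py."""
    scale_names = scale_names or []
    # ... shard the MoE node across expert-parallel ranks here ...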