Commit e02d44e

bnellnm and mgoin authored and committed
[Kernels] Clean up FusedMoeMethodBase and modular kernel setup. Remove extra arguments from modular kernel methods. (vllm-project#22035)
Signed-off-by: Bill Nell <[email protected]>
Co-authored-by: Michael Goin <[email protected]>
1 parent d1331a4 commit e02d44e

54 files changed, with 2022 additions and 1305 deletions

.buildkite/test-pipeline.yaml

Lines changed: 1 addition & 0 deletions
@@ -399,6 +399,7 @@ steps:
 - label: Kernels MoE Test %N
   mirror_hardwares: [amdexperimental]
   source_file_dependencies:
+  - csrc/quantization/cutlass_w8a8/moe/
   - csrc/moe/
   - tests/kernels/moe
   - vllm/model_executor/layers/fused_moe/

docs/design/fused_moe_modular_kernel.md

Lines changed: 9 additions & 1 deletion
@@ -175,11 +175,19 @@ implementations that input `FusedMoEActivationFormat.Standard` support chunking
 
 ### FusedMoEModularKernel Initialization
 
-`FusedMoEMethodBase` class has 2 methods that are collectively responsible in creating the `FusedMoEModularKernel` object. They are,
+`FusedMoEMethodBase` class has 3 methods that are collectively responsible in creating the `FusedMoEModularKernel` object. They are,
 
+* maybe_make_prepare_finalize,
 * select_gemm_impl, and
 * init_prepare_finalize
 
+#### maybe_make_prepare_finalize
+
+The `maybe_make_prepare_finalize` method is responsbile for constructing an instance of `FusedMoEPrepareAndFinalize` when appropriate based on the current all2all backend, e.g. when EP + DP is enabled. The base class method currently constructs all the `FusedMoEPrepareAndFinalize` objects for the EP+DP case. Derived classes can override this method to construct prepare/finalize objects for different scenarios, e.g. `ModelOptNvFp4FusedMoE` can construct a `FlashInferCutlassMoEPrepareAndFinalize` for the EP+TP case.
+Please refer to the implementations in,
+
+* `ModelOptNvFp4FusedMoE`
+
 #### select_gemm_impl
 
 The `select_gemm_impl` method is undefined in the base class. It is the responsibility of the derived class to implement a method that constructs a valid/appropriate `FusedMoEPermuteExpertsUnpermute` object.
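
Reviewer note: the three-method split described in this doc change can be illustrated with a toy model. The sketch below is not vLLM code. `PrepareFinalize`, `Experts`, `ModularKernel`, `MoEMethodBase`, `CustomMoEMethod`, and the `use_ep_dp` flag are hypothetical stand-ins with simplified signatures, and the way `init_prepare_finalize` chains the other two methods is an assumption, since the doc only says the three are "collectively responsible" for building the kernel.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class PrepareFinalize:
    """Hypothetical stand-in for FusedMoEPrepareAndFinalize."""
    backend: str


@dataclass
class Experts:
    """Hypothetical stand-in for FusedMoEPermuteExpertsUnpermute."""
    name: str


@dataclass
class ModularKernel:
    """Hypothetical stand-in for FusedMoEModularKernel."""
    prepare_finalize: PrepareFinalize
    experts: Experts


class MoEMethodBase:
    """Toy base class mirroring the three methods named in the doc."""

    def __init__(self, use_ep_dp: bool):
        self.use_ep_dp = use_ep_dp  # assumed flag, not a real vLLM attribute
        self.kernel: Optional[ModularKernel] = None

    def maybe_make_prepare_finalize(self) -> Optional[PrepareFinalize]:
        # Base behaviour: only build an object for the EP+DP case.
        return PrepareFinalize("all2all") if self.use_ep_dp else None

    def select_gemm_impl(self, prepare_finalize: PrepareFinalize) -> Experts:
        # Undefined in the base class; derived classes must provide it.
        raise NotImplementedError

    def init_prepare_finalize(self) -> None:
        # Assumed orchestration: combine the two pieces into one kernel.
        prepare_finalize = self.maybe_make_prepare_finalize()
        if prepare_finalize is None:
            return  # no modular kernel for this configuration
        experts = self.select_gemm_impl(prepare_finalize)
        self.kernel = ModularKernel(prepare_finalize, experts)


class CustomMoEMethod(MoEMethodBase):
    """Derived class that overrides both hooks."""

    def maybe_make_prepare_finalize(self) -> Optional[PrepareFinalize]:
        base = super().maybe_make_prepare_finalize()
        # Fall back to a custom object for a scenario the base class skips.
        return base if base is not None else PrepareFinalize("custom")

    def select_gemm_impl(self, prepare_finalize: PrepareFinalize) -> Experts:
        return Experts(f"experts-for-{prepare_finalize.backend}")
```

In this toy, a derived class gets a kernel even when the base class would have built none, which is the kind of override the doc attributes to `ModelOptNvFp4FusedMoE`.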

examples/offline_inference/data_parallel.py

Lines changed: 22 additions & 1 deletion
@@ -70,12 +70,27 @@ def parse_args():
         default=64,
         help=("Maximum number of sequences to be processed in a single iteration."),
     )
+    parser.add_argument(
+        "--max-model-len",
+        type=int,
+        help=("Maximum number of tokens to be processed in a single iteration."),
+    )
+    parser.add_argument(
+        "--timeout",
+        type=int,
+        default=300,
+        help=("Number of seconds before unresponsive process is killed."),
+    )
     parser.add_argument(
         "--gpu-memory-utilization",
         type=float,
         default=0.8,
         help=("Fraction of GPU memory vLLM is allowed to allocate (0.0, 1.0]."),
     )
+    parser.add_argument(
+        "--quantization",
+        type=str,
+    )
     return parser.parse_args()
 
 
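
Reviewer note: a minimal sketch of how the three new flags parse, using only the standard `argparse` module. `build_parser` is a hypothetical helper that reproduces just the added arguments, and `"fp8"` is only an example value. Because `--max-model-len` and `--quantization` have no default, they come back as `None` when omitted; how the downstream `LLM(...)` call treats `None` is not shown in this diff.

```python
import argparse


def build_parser() -> argparse.ArgumentParser:
    """Hypothetical helper containing only the flags added in this hunk."""
    parser = argparse.ArgumentParser(description="Sketch of the new flags")
    parser.add_argument("--max-model-len", type=int)  # no default, so None
    parser.add_argument("--timeout", type=int, default=300)
    parser.add_argument("--quantization", type=str)  # no default, so None
    return parser


# Example invocation with an explicit timeout and an example quantization value.
args = build_parser().parse_args(["--timeout", "60", "--quantization", "fp8"])
print(args.max_model_len, args.timeout, args.quantization)  # None 60 fp8
```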

@@ -90,7 +105,9 @@ def main(
     enforce_eager,
     trust_remote_code,
     max_num_seqs,
+    max_model_len,
     gpu_memory_utilization,
+    quantization,
 ):
     os.environ["VLLM_DP_RANK"] = str(global_dp_rank)
     os.environ["VLLM_DP_RANK_LOCAL"] = str(local_dp_rank)
@@ -142,7 +159,9 @@ def start(rank):
         enable_expert_parallel=True,
         trust_remote_code=trust_remote_code,
         max_num_seqs=max_num_seqs,
+        max_model_len=max_model_len,
         gpu_memory_utilization=gpu_memory_utilization,
+        quantization=quantization,
     )
     outputs = llm.generate(prompts, sampling_params)
     # Print the outputs.
@@ -198,14 +217,16 @@ def start(rank):
                 args.enforce_eager,
                 args.trust_remote_code,
                 args.max_num_seqs,
+                args.max_model_len,
                 args.gpu_memory_utilization,
+                args.quantization,
             ),
         )
         proc.start()
         procs.append(proc)
     exit_code = 0
     for proc in procs:
-        proc.join(timeout=300)
+        proc.join(timeout=args.timeout)
         if proc.exitcode is None:
             print(f"Killing process {proc.pid} that didn't stop within 5 minutes.")
             proc.kill()
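
Reviewer note: with the timeout now configurable, the log message in this hunk still says "5 minutes". The sketch below shows one way to keep the message in step with the value. `join_or_kill` is a hypothetical helper, not part of the example script, and the exit-code handling is an assumption because the rest of the loop falls outside this hunk. The timeout applies to each `join` call separately, so the total wait can reach the number of processes times the timeout.

```python
def join_or_kill(procs, timeout: float) -> int:
    """Join each process and kill any that is still running afterwards.

    Works with multiprocessing.Process objects (join, exitcode, pid, kill).
    Returns 0 if every process exited cleanly, otherwise a non-zero code.
    """
    exit_code = 0
    for proc in procs:
        proc.join(timeout=timeout)
        if proc.exitcode is None:
            # Still alive after the timeout: report the real limit, then kill.
            print(
                f"Killing process {proc.pid} that didn't stop "
                f"within {timeout} seconds."
            )
            proc.kill()
            exit_code = 1
        elif proc.exitcode:
            # Assumed policy: propagate a non-zero exit code from a child.
            exit_code = proc.exitcode
    return exit_code
```

A caller would pass the list of started processes and `args.timeout`, then hand the result to `exit()`.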