Tiny optimizations for moe #1717
File: MoE benchmark script (bench_cutlass_fused_moe)
@@ -29,20 +29,24 @@
 FLOAT4_E2M1_MAX = 6.0
 FLOAT8_E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max

+num_ranks = 2
+
 test_configs = [
+    # {
+    #     "hidden_size": 7168,
+    #     "num_experts": 256,
+    #     "top_k": 8,
+    #     "intermediate_size": 256,
+    # },
-    {
-        "hidden_size": 7168,
-        "num_experts": 256,
-        "top_k": 8,
-        "intermediate_size": 256,
-    },
     {
         "hidden_size": 7168,
-        "num_experts": 32,
+        "num_experts": num_experts,
         "top_k": 8,
         "intermediate_size": 2048,
-    },
+    }
+    for num_experts in [
+        256 // num_ranks,
+    ]
 ]
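For readability: the hunk rewrites `test_configs` as a list comprehension, comments out the first config, and parametrizes the second on `num_experts`. Assuming that reconstruction is right, with `num_ranks = 2` the comprehension appears to evaluate to a single config (a sketch, not verified against the full file):

```python
# Hypothetical expansion of the comprehension above with num_ranks = 2.
test_configs = [
    {
        "hidden_size": 7168,
        "num_experts": 128,  # 256 // num_ranks, i.e. experts per rank
        "top_k": 8,
        "intermediate_size": 2048,
    },
]
```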
@@ -131,6 +135,13 @@ def bench_cutlass_fused_moe(
     router_logits = torch.randn(m, e, dtype=otype).cuda()
     routing_weights, selected_experts = compute_routing(router_logits, top_k)

+    if 1:
+        print("HACK: mask some selected_experts")
+        selected_experts[torch.randn(selected_experts.shape) > 1 / num_ranks] = 9999999
+
+    tune_max_num_tokens = batch_size
+    print(f"HACK: {tune_max_num_tokens=}")
+
     flash_output = torch.zeros_like(x)

     quant_scales = [
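A note on the masking hack above: `torch.randn` samples a standard normal, so the masked fraction is P(Z > 1/num_ranks) ≈ 31% for num_ranks = 2, not exactly 1 − 1/num_ranks. A quick check (sketch, for illustration only):

```python
import torch

num_ranks = 2
# Fraction of (token, expert) slots the hack pushes out of range:
mask = torch.randn(1_000_000) > 1 / num_ranks
print(mask.float().mean().item())  # ~0.31, i.e. P(Z > 0.5) for standard normal Z
```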
@@ -143,6 +154,7 @@ def bench_cutlass_fused_moe(
     ]
     hidden_states = x
     hidden_states, input_sf = fp4_quantize(x, a1_gs)
+    print(f"{hidden_states.shape=}")

     # Warmup
     for _ in range(3):
@@ -156,7 +168,7 @@ def bench_cutlass_fused_moe( | |
quant_scales=quant_scales, | ||
input_sf=input_sf, | ||
output=flash_output, | ||
tune_max_num_tokens=16384, | ||
tune_max_num_tokens=tune_max_num_tokens, | ||
) | ||
|
||
if not skip_autotune: | ||
|
@@ -171,10 +183,20 @@ def bench_cutlass_fused_moe(
             quant_scales=quant_scales,
             input_sf=input_sf,
             output=flash_output,
-            tune_max_num_tokens=16384,
+            tune_max_num_tokens=tune_max_num_tokens,
         )
-    ms_list = bench_gpu_time(
-        lambda: fused_moe.cutlass_fused_moe(
+
+    counter = 0
+
+    def f():
+        nonlocal counter
+        counter += 1
+
+        if counter == 10:
+            print("hi call cudaProfilerStart")
+            torch.cuda.cudart().cudaProfilerStart()
+
+        fused_moe.cutlass_fused_moe(
             hidden_states,
             selected_experts.to(torch.int),
             routing_weights,
@@ -184,8 +206,13 @@ def bench_cutlass_fused_moe(
             quant_scales=quant_scales,
             input_sf=input_sf,
             output=flash_output,
-        ),
-    )
+        )
+
+        if counter == 10:
+            print("hi call cudaProfilerStop")
+            torch.cuda.cudart().cudaProfilerStop()
+
+    ms_list = bench_gpu_time(f)
     median_ms = np.median(ms_list)
     print(f"{'input':<15} {'weight1':<20} {'weight2':<20} {'time(ms)'}")
     print(
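For context on the counter-gated profiler calls above: wrapping one steady-state iteration in `cudaProfilerStart()`/`cudaProfilerStop()` pairs with Nsight Systems' capture-range mode (e.g. `nsys profile --capture-range=cudaProfilerApi`), so the trace covers only that iteration rather than the warmup and autotuning phases.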
@@ -201,7 +228,7 @@ def bench_cutlass_fused_moe(
         help="Update the config file with the new profiling results",
     )
     parser.add_argument(
-        "--num-tokens", type=int, default=32, help="Number of tokens to profile"
+        "--num-tokens", type=int, default=32768 * num_ranks, help="Number of tokens to profile"
     )
     parser.add_argument("--skip-autotune", action="store_true", help="Skip autotuning")
     args = parser.parse_args()
File: CUDA MoE kernel source (findTotalEltsLessThanTarget)
@@ -865,7 +865,7 @@ void threeStepBuildExpertMapsSortFirstToken(
 // ============================== Infer GEMM sizes =================================
 // TODO Could linear search be better for small # experts
 template <class T>
-__device__ inline int64_t findTotalEltsLessThanTarget(T const* sorted_indices,
+__device__ inline int64_t findTotalEltsLessThanTarget_v1(T const* sorted_indices,
                                                        int64_t const arr_length, T const target) {
   int64_t low = 0, high = arr_length - 1, target_location = -1;
   while (low <= high) {
@@ -881,6 +881,49 @@ __device__ inline int64_t findTotalEltsLessThanTarget(T const* sorted_indices,
   return target_location + 1;
 }

+template <class T>
+__device__ inline int64_t findTotalEltsLessThanTarget_v2(T const* sorted_indices, int64_t const arr_length, T const target) {
+  constexpr int ARR_LENGTH_CONST = 128;
+  if (arr_length != ARR_LENGTH_CONST) {
+    asm("trap;");
+  }

Review comment on lines 886 to 889: "The function …"

+
+  constexpr unsigned full_mask = 0xffffffffu;
+  constexpr int WARP_SZ = 32;
+  const int lane_id = threadIdx.x & (WARP_SZ - 1);
+
+  int local_count = 0;
+#pragma unroll
+  for (int k = 0; k < ARR_LENGTH_CONST / WARP_SZ; ++k) {
+    const int idx = lane_id + k * WARP_SZ;
+    T v = sorted_indices[idx];
+    local_count += (v < target) ? 1 : 0;
+  }
+
+#pragma unroll
+  for (int offset = 16; offset > 0; offset >>= 1) {
+    local_count += __shfl_down_sync(full_mask, local_count, offset);
+  }
+  int total = __shfl_sync(full_mask, local_count, 0);
+
+  return (int64_t)total;
+}
+
+template <class T>
+__device__ inline int64_t findTotalEltsLessThanTarget(T const* sorted_indices, int64_t const arr_length, T const target) {
+  return findTotalEltsLessThanTarget_v1(sorted_indices, arr_length, target);
+
+  // return findTotalEltsLessThanTarget_v2(sorted_indices, arr_length, target);
+
+  // int64_t out_v1 = findTotalEltsLessThanTarget_v1(sorted_indices, arr_length, target);
+  // int64_t out_v2 = findTotalEltsLessThanTarget_v2(sorted_indices, arr_length, target);
+  // if (out_v1 != out_v2) {
+  //   printf("different output! v1=%lld v2=%lld\n", out_v1, out_v2);
+  //   asm("trap;");
+  // }
+  // return out_v1;
+}

Review comment on lines 913 to 925.

 template <class T>
 using sizeof_bits = cutlass::sizeof_bits<
     typename cutlass_kernels::TllmToCutlassTypeAdapter<std::remove_cv_t<T>>::type>;
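Both variants compute the same quantity: the number of entries in a sorted array strictly less than `target`, i.e. the lower-bound insertion point used here to infer per-expert GEMM sizes. v1 binary-searches; v2 assumes a fixed length of 128, has each of the 32 lanes in a warp count a strided slice, and reduces the partial counts with `__shfl_down_sync`. A NumPy sketch of the equivalence (illustrative only, not the kernel code):

```python
import numpy as np

rng = np.random.default_rng(0)
sorted_indices = np.sort(rng.integers(0, 64, size=128))
target = 17

v1 = int(np.searchsorted(sorted_indices, target, side="left"))  # binary search, like v1
v2 = int((sorted_indices < target).sum())  # direct count, like v2's warp-parallel sum
assert v1 == v2
```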
File: autotuner (AutoTuner.choose_one)
@@ -447,6 +447,11 @@ def choose_one(
             logger.debug(
                 f"[AutoTunner]: Generated key{AutoTuner._get_cache_key(custom_op, runners[0], input_shapes, tuning_config)}"
             )
+        else:
+            # NOTE ADD
+            logger.debug(
+                f"[AutoTunner]: HACK ADD cache hit {custom_op=} {input_shapes=}"
+            )
         return runner, tactic

Review comment on lines +450 to 455.

         assert len(runners) > 0, "At least one runner is required"
File: Python cutlass_fused_moe wrapper
@@ -870,6 +870,12 @@ def cutlass_fused_moe(
     if enable_pdl is None:
         enable_pdl = device_support_pdl(input.device)

+    print(
+        "hi flashinfer cutlass_fused_moe "
+        f"{input.shape=} {input.dtype=} "
+        f"{token_selected_experts.shape=}"
+    )
+
     num_rows = input.shape[0]
     if min_latency_mode:
         num_rows *= fc2_expert_weights.shape[0]
File: pyproject.toml
@@ -17,11 +17,11 @@ name = "flashinfer-python"
 description = "FlashInfer: Kernel Library for LLM Serving"
 requires-python = ">=3.9,<4.0"
 authors = [{ name = "FlashInfer team" }]
-license = "Apache-2.0"
+#license = "Apache-2.0"
 readme = "README.md"
 urls = { Homepage = "https://github.com/flashinfer-ai/flashinfer" }
 dynamic = ["dependencies", "version"]
-license-files = ["LICENSE", "licenses/*"]
+#license-files = ["LICENSE", "licenses/*"]

 [build-system]
 requires = ["setuptools>=77", "packaging>=24"]
Review comment: This `if 1:` block appears to be temporary code for debugging and testing, as indicated by the "HACK" print statements. The block, including the hardcoded logic for masking experts and setting `tune_max_num_tokens`, should be removed before merging.
, should be removed before merging.