Skip to content

Commit 68f3d68

Browse files
authored
Fix trtllm_ar failure (#1423)
Fix `TypeError: cannot be converted to pointer` error in AR+RMSNorm. This was broken by 85d75ca <!-- .github/pull_request_template.md --> ## 📌 Description <!-- What does this PR do? Briefly describe the changes and why they’re needed. --> ## 🔍 Related Issues <!-- Link any related issues here --> ## 🚀 Pull Request Checklist Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete. ### ✅ Pre-commit Checks - [ ] I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method). - [ ] I have installed the hooks with `pre-commit install`. - [ ] I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues. > If you are unsure about how to set up `pre-commit`, see [the pre-commit documentation](https://pre-commit.com/). ## 🧪 Tests - [ ] Tests have been added or updated as needed. - [ ] All tests are passing (`unittest`, etc.). ## Reviewer Notes <!-- Optional: anything you'd like reviewers to focus on, concerns, etc. --> Signed-off-by: Po-Han Huang <[email protected]>
1 parent ade3885 commit 68f3d68

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

flashinfer/comm/trtllm_ar.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616

1717
import functools
1818
import logging
19-
from ctypes import c_void_p
19+
from ctypes import c_void_p, cast
2020
from types import SimpleNamespace
2121
from typing import List, Optional, Tuple, Union
2222

@@ -610,7 +610,7 @@ def trtllm_create_ipc_workspace_for_all_reduce_fusion(
610610
# Set flag_ptr[3] = lamport_comm_size
611611
lamport_comm_size_bytes = lamport_comm_size.to_bytes(4, byteorder="little")
612612
cudart.cudaMemcpy(
613-
c_void_p(flag_ptr.value + 3 * 4), c_void_p(lamport_comm_size_bytes), 4
613+
c_void_p(flag_ptr.value + 3 * 4), cast(lamport_comm_size_bytes, c_void_p), 4
614614
)
615615
print("set flag_ptr[3] = lamport_comm_size: ", lamport_comm_size)
616616
# add flag_ptr to workspace

0 commit comments

Comments
 (0)