Skip to content

Commit a89d5e9

Browse files
msaroufimpytorchmergebot
authored andcommitted
compile_kernel remove header_code arg (pytorch#163165)
We previously asked users to seperate these because we didn't have any way of adding extern C declarations. Now we don't and we don't need this confusing flag anymore BC breaking but is fine for this API since it doesn't have major users yet. Please just put your all your code in `kernel_source` moving forward ## BC note The header_code parameter has been removed from torch.cuda._compile_kernel. Previously, users could pass separate header code that would be prepended to the kernel source. Now, header code must be included directly in the kernel_source parameter. Note this only affects torch.cuda._compile_kernel, which is a private API. Example: Before ```python kernel = compile_kernel( kernel_source="global void my_kernel() { ... }", kernel_name="my_kernel", header_code="#define SCALE 2.0f\n__device_ float scale(float x) { return x * SCALE; }" ) ``` After ```python kernel_source = """ #define SCALE 2.0f device float scale(float x) { return x * SCALE; } global void my_kernel() { ... } """ kernel = _compile_kernel(kernel_source, "my_kernel") ``` Pull Request resolved: pytorch#163165 Approved by: https://github.com/janeyx99, https://github.com/albanD
1 parent 4660e38 commit a89d5e9

File tree

3 files changed

+3
-18
lines changed

3 files changed

+3
-18
lines changed

test/test_cuda.py

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -6848,25 +6848,21 @@ def test_compile_kernel(self):
68486848
self.assertEqual(c_int, expected_int)
68496849

68506850
# Test with header code
6851-
header_code = """
6851+
scale_kernel_source = """
68526852
#define SCALE_FACTOR 2.0f
68536853
68546854
__device__ float scale_value(float val) {
68556855
return val * SCALE_FACTOR;
68566856
}
6857-
"""
68586857
6859-
scale_kernel_source = """
68606858
__global__ void scale_tensors(const float* input, float* output, int n) {
68616859
int i = threadIdx.x + blockIdx.x * blockDim.x;
68626860
if (i < n)
68636861
output[i] = scale_value(input[i]);
68646862
}
68656863
"""
68666864

6867-
scale_kernel = _compile_kernel(
6868-
scale_kernel_source, "scale_tensors", header_code=header_code
6869-
)
6865+
scale_kernel = _compile_kernel(scale_kernel_source, "scale_tensors")
68706866

68716867
input_tensor = torch.rand(N, device="cuda")
68726868
output_tensor = torch.empty_like(input_tensor)

torch/cuda/__init__.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1733,7 +1733,6 @@ def _compile_kernel(
17331733
kernel_source: str,
17341734
kernel_name: str,
17351735
compute_capability: Optional[str] = None,
1736-
header_code: str = "",
17371736
cuda_include_dirs: Optional[list] = None,
17381737
nvcc_options: Optional[list] = None,
17391738
):
@@ -1750,7 +1749,6 @@ def _compile_kernel(
17501749
kernel_name (str): The name of the kernel function to compile
17511750
compute_capability (str, optional): The compute capability to target (e.g., "86").
17521751
If None, will detect from current device.
1753-
header_code (str, optional): Additional header code to prepend to the kernel source
17541752
cuda_include_dirs (list, optional): List of directories containing CUDA headers
17551753
nvcc_options (list, optional): Additional options to pass to NVRTC
17561754
@@ -1780,7 +1778,6 @@ def _compile_kernel(
17801778
kernel_source,
17811779
kernel_name,
17821780
compute_capability,
1783-
header_code,
17841781
cuda_include_dirs,
17851782
nvcc_options,
17861783
)

torch/cuda/_utils.py

Lines changed: 1 addition & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -114,7 +114,6 @@ def _nvrtc_compile(
114114
kernel_source: str,
115115
kernel_name: str,
116116
compute_capability: Optional[str] = None,
117-
header_code: str = "",
118117
cuda_include_dirs: Optional[list] = None,
119118
nvcc_options: Optional[list] = None,
120119
auto_pch: bool = False,
@@ -127,7 +126,6 @@ def _nvrtc_compile(
127126
kernel_name (str): The name of the kernel function to compile
128127
compute_capability (str, None): The compute capability to target (e.g., "86").
129128
If None, will detect from current device.
130-
header_code (str, optional): Additional header code to prepend to the kernel source
131129
cuda_include_dirs (list, None): List of directories containing CUDA headers
132130
nvcc_options (list, None): Additional options to pass to NVRTC
133131
auto_pch (bool): Enable automatic precompiled headers (CUDA 12.8+)
@@ -156,14 +154,8 @@ def check_nvrtc(result: int) -> None:
156154
)
157155
raise RuntimeError(f"CUDA error: {error_message}")
158156

159-
# Combine header code and kernel source
160-
if header_code:
161-
full_source = header_code + "\n" + kernel_source
162-
else:
163-
full_source = kernel_source
164-
165157
# Convert source to bytes
166-
source_bytes = full_source.encode("utf-8")
158+
source_bytes = kernel_source.encode("utf-8")
167159

168160
# Get compute capability if not provided
169161
if compute_capability is None:

0 commit comments

Comments
 (0)