Skip to content

Commit fa3d9d3

Browse files
committed
Test CUDA conv2D type conversion fix
1 parent 358fecc commit fa3d9d3

File tree

2 files changed

+13
-5
lines changed

2 files changed

+13
-5
lines changed

.github/workflows/menlo-build.yml

Lines changed: 2 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -56,7 +56,6 @@ jobs:
5656

5757
build-and-test:
5858
runs-on: ${{ matrix.runs-on }}
59-
needs: [create-draft-release]
6059
timeout-minutes: 270
6160
strategy:
6261
fail-fast: false
@@ -285,7 +284,7 @@ jobs:
285284
uses: actions/checkout@v3
286285
with:
287286
submodules: recursive
288-
287+
289288
- name: Replace our Makefile
290289
run: |
291290
cat menlo/Makefile | tee Makefile
@@ -635,4 +634,4 @@ jobs:
635634
upload_url: ${{ needs.create-draft-release.outputs.upload_url }}
636635
asset_path: /tmp/cudart-llama-bin-win-cu11.7-x64.tar.gz
637636
asset_name: cudart-llama-bin-win-cu11.7-x64.tar.gz
638-
asset_content_type: application/gzip
637+
asset_content_type: application/gzip

ggml/src/ggml-cuda/conv2d.cu

Lines changed: 11 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -17,6 +17,15 @@ struct kernel_bounds {
1717
int64_t x_min, x_max;
1818
};
1919

20+
template<typename T>
21+
__device__ __forceinline__ float to_float(const T& val) {
22+
if constexpr (std::is_same_v<T, __half>) {
23+
return __half2float(val);
24+
} else {
25+
return val; // Assumes T is float
26+
}
27+
}
28+
2029
__device__ __forceinline__ int64_t max64(int64_t a, int64_t b) {
2130
return (a > b) ? a : b;
2231
}
@@ -94,8 +103,8 @@ static __global__ void conv2d_kernel(const float * __restrict__ input,
94103
const int64_t in_x = calculate_input_coord(out_x, kx, P.ST_X, P.DL_X, P.PD_X);
95104

96105
const float input_val = input[Layout::input_index(n, c_in, in_y, in_x, P)];
97-
const float kernel_val = kernel[Layout::kernel_index(c_out, c_in, ky, kx, P)];
98-
acc += (input_val * kernel_val);
106+
const T kernel_val = kernel[Layout::kernel_index(c_out, c_in, ky, kx, P)];
107+
acc += (input_val * to_float(kernel_val));
99108
}
100109
}
101110
}

0 commit comments

Comments
 (0)