Skip to content

Commit 6628efa

Browse files
authored
Merge pull request #450 from NVIDIA/cublas_fix
2 parents 0a3258d + 62fe2bf commit 6628efa

File tree

4 files changed

+115
-26
lines changed

4 files changed

+115
-26
lines changed

WORKSPACE

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,11 @@ new_local_repository(
3838
path = "/usr/local/cuda-11.1/",
3939
)
4040

41+
new_local_repository(
42+
name = "cublas",
43+
build_file = "@//third_party/cublas:BUILD",
44+
path = "/usr",
45+
)
4146
#############################################################################################################
4247
# Tarballs and fetched dependencies (default - use in cases when building from precompiled bin and tarballs)
4348
#############################################################################################################

third_party/cublas/BUILD

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
package(default_visibility = ["//visibility:public"])
2+
3+
# NOTE: This BUILD file is only really targeted at aarch64, the rest of the configuration is just to satisfy bazel, x86 uses the cublas source from the CUDA build file since it will be versioned with CUDA.
4+
5+
config_setting(
6+
name = "aarch64_linux",
7+
constraint_values = [
8+
"@platforms//cpu:aarch64",
9+
"@platforms//os:linux",
10+
],
11+
)
12+
13+
config_setting(
14+
name = "windows",
15+
constraint_values = [
16+
"@platforms//os:windows",
17+
],
18+
)
19+
20+
cc_library(
21+
name = "cublas_headers",
22+
hdrs = select({
23+
":aarch64_linux": ["include/cublas.h"] + glob(["usr/include/cublas+.h"]),
24+
"//conditions:default": ["local/cuda/include/cublas.h"] + glob(["usr/cuda/include/cublas+.h"]),
25+
}),
26+
includes = ["include/"],
27+
visibility = ["//visibility:private"],
28+
)
29+
30+
cc_import(
31+
name = "cublas_lib",
32+
shared_library = select({
33+
":aarch64_linux": "lib/aarch64-linux-gnu/libcublas.so",
34+
":windows": "lib/x64/cublas.lib",
35+
"//conditions:default": "local/cuda/targets/x86_64-linux/lib/libcublas.so",
36+
}),
37+
visibility = ["//visibility:private"],
38+
)
39+
40+
cc_import(
41+
name = "cublas_lt_lib",
42+
shared_library = select({
43+
":aarch64_linux": "lib/aarch64-linux-gnu/libcublasLt.so",
44+
"//conditions:default": "local/cuda/targets/x86_64-linux/lib/libcublasLt.so",
45+
}),
46+
visibility = ["//visibility:private"],
47+
)
48+
49+
cc_library(
50+
name = "cublas",
51+
visibility = ["//visibility:public"],
52+
deps = [
53+
"cublas_headers",
54+
"cublas_lib",
55+
"cublas_lt_lib",
56+
],
57+
)

third_party/tensorrt/archive/BUILD

Lines changed: 27 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,20 @@
11
package(default_visibility = ["//visibility:public"])
22

3+
config_setting(
4+
name = "aarch64_linux",
5+
constraint_values = [
6+
"@platforms//cpu:aarch64",
7+
"@platforms//os:linux",
8+
],
9+
)
10+
11+
config_setting(
12+
name = "windows",
13+
constraint_values = [
14+
"@platforms//os:windows",
15+
],
16+
)
17+
318
cc_library(
419
name = "nvinfer_headers",
520
hdrs = glob(
@@ -27,10 +42,13 @@ cc_library(
2742
deps = [
2843
"nvinfer_headers",
2944
"nvinfer_lib",
30-
"@cuda//:cublas",
3145
"@cuda//:cudart",
3246
"@cudnn",
33-
],
47+
] + select({
48+
":aarch64_linux": ["@cublas//:cublas"],
49+
":windows": ["@cuda//:cublas"],
50+
"//conditions:default": ["@cuda//:cublas"],
51+
}),
3452
)
3553

3654
####################################################################################
@@ -165,5 +183,11 @@ cc_library(
165183
"nvinfer",
166184
"nvinferplugin_headers",
167185
"nvinferplugin_lib",
168-
],
186+
"@cuda//:cudart",
187+
"@cudnn",
188+
] + select({
189+
":aarch64_linux": ["@cublas//:cublas"],
190+
":windows": ["@cuda//:cublas"],
191+
"//conditions:default": ["@cuda//:cublas"],
192+
}),
169193
)

third_party/tensorrt/local/BUILD

Lines changed: 26 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -86,10 +86,13 @@ cc_library(
8686
deps = [
8787
"nvinfer_headers",
8888
"nvinfer_lib",
89-
"@cuda//:cublas",
9089
"@cuda//:cudart",
9190
"@cudnn",
92-
],
91+
] + select({
92+
":aarch64_linux": ["@cublas//:cublas"],
93+
":windows": ["@cuda//:cublas"],
94+
"//conditions:default": ["@cuda//:cublas"],
95+
}),
9396
)
9497

9598
####################################################################################
@@ -274,47 +277,47 @@ cc_library(
274277

275278
cc_library(
276279
name = "nvcaffeparser",
277-
visibility = ["//visibility:public"],
278280
deps = [
279281
"nvcaffeparser_headers",
280282
"nvcaffeparser_lib",
281283
"nvinfer",
282284
],
285+
visibility = ["//visibility:public"],
283286
)
284287

285288
####################################################################################
286289

287-
cc_import(
288-
name = "nvinferplugin_lib",
289-
shared_library = select({
290-
":aarch64_linux": "lib/x86_64-linux-gnu/libnvinfer_plugin.so",
291-
":windows": "lib/nvinfer_plugin.dll",
292-
"//conditions:default": "lib/x86_64-linux-gnu/libnvinfer_plugin.so",
293-
}),
294-
visibility = ["//visibility:private"],
295-
)
296-
297290
cc_library(
298-
name = "nvinferplugin_headers",
291+
name = "nvinferplugin",
299292
hdrs = select({
300293
":aarch64_linux": glob(["include/aarch64-linux-gnu/NvInferPlugin*.h"]),
301294
":windows": glob(["include/NvInferPlugin*.h"]),
302295
"//conditions:default": glob(["include/x86_64-linux-gnu/NvInferPlugin*.h"]),
303296
}),
297+
srcs = select({
298+
":aarch64_linux": ["lib/aarch64-linux-gnu/libnvinfer_plugin.so"],
299+
":windows": ["lib/nvinfer_plugin.dll"],
300+
"//conditions:default": ["lib/x86_64-linux-gnu/libnvinfer_plugin.so"],
301+
}),
304302
includes = select({
305-
":aarch64_linux": ["include/aarch64-linux-gnu"],
303+
":aarch64_linux": ["include/aarch64-linux-gnu/"],
306304
":windows": ["include/"],
307305
"//conditions:default": ["include/x86_64-linux-gnu/"],
308306
}),
309-
visibility = ["//visibility:private"],
310-
)
311-
312-
cc_library(
313-
name = "nvinferplugin",
314-
visibility = ["//visibility:public"],
315307
deps = [
316308
"nvinfer",
317-
"nvinferplugin_headers",
318-
"nvinferplugin_lib",
309+
"@cuda//:cudart",
310+
"@cudnn",
311+
] + select({
312+
":aarch64_linux": ["@cublas//:cublas"],
313+
":windows": ["@cuda//:cublas"],
314+
"//conditions:default": ["@cuda//:cublas"],
315+
}),
316+
alwayslink = True,
317+
copts = [
318+
"-pthread"
319319
],
320+
linkopts = [
321+
"-lpthread",
322+
]
320323
)

0 commit comments

Comments
 (0)