@@ -228,11 +228,26 @@ macro(set_gencode_flags_for_srcs)
228228 "${multiValueArgs} " ${ARGN} )
229229
230230 foreach (_ARCH ${arg_CUDA_ARCHS} )
231- string (REPLACE "." "" _ARCH "${_ARCH} " )
232- set_gencode_flag_for_srcs(
233- SRCS ${arg_SRCS}
234- ARCH "compute_${_ARCH} "
235- CODE "sm_${_ARCH} " )
231+ # handle +PTX suffix: generate both sm and ptx codes if requested
232+ string (FIND "${_ARCH} " "+PTX" _HAS_PTX)
233+ if (NOT _HAS_PTX EQUAL -1)
234+ string (REPLACE "+PTX" "" _BASE_ARCH "${_ARCH} " )
235+ string (REPLACE "." "" _STRIPPED_ARCH "${_BASE_ARCH} " )
236+ set_gencode_flag_for_srcs(
237+ SRCS ${arg_SRCS}
238+ ARCH "compute_${_STRIPPED_ARCH} "
239+ CODE "sm_${_STRIPPED_ARCH} " )
240+ set_gencode_flag_for_srcs(
241+ SRCS ${arg_SRCS}
242+ ARCH "compute_${_STRIPPED_ARCH} "
243+ CODE "compute_${_STRIPPED_ARCH} " )
244+ else ()
245+ string (REPLACE "." "" _STRIPPED_ARCH "${_ARCH} " )
246+ set_gencode_flag_for_srcs(
247+ SRCS ${arg_SRCS}
248+ ARCH "compute_${_STRIPPED_ARCH} "
249+ CODE "sm_${_STRIPPED_ARCH} " )
250+ endif ()
236251 endforeach ()
237252
238253 if (${arg_BUILD_PTX_FOR_ARCH} )
@@ -251,7 +266,10 @@ endmacro()
251266#
252267# For the given `SRC_CUDA_ARCHS` list of gencode versions in the form
253268# `<major>.<minor>[letter]` compute the "loose intersection" with the
254- # `TGT_CUDA_ARCHS` list of gencodes.
269+ # `TGT_CUDA_ARCHS` list of gencodes. We also support the `+PTX` suffix in
270+ # `SRC_CUDA_ARCHS` which indicates that the PTX code should be built when there
271+ # is a CUDA_ARCH in `TGT_CUDA_ARCHS` that is equal to or larger than the
272+ # architecture in `SRC_CUDA_ARCHS`.
255273# The loose intersection is defined as:
256274# { max{ x \in src | x <= y } | y \in tgt, { x \in src | x <= y } != {} }
257275# where `<=` is the version comparison operator.
@@ -268,44 +286,63 @@ endmacro()
268286# cuda_archs_loose_intersection(OUT_CUDA_ARCHS SRC_CUDA_ARCHS TGT_CUDA_ARCHS)
269287# OUT_CUDA_ARCHS="8.0;8.6;9.0;9.0a"
270288#
289+ # Example With PTX:
290+ # SRC_CUDA_ARCHS="8.0+PTX"
291+ # TGT_CUDA_ARCHS="9.0"
292+ # cuda_archs_loose_intersection(OUT_CUDA_ARCHS SRC_CUDA_ARCHS TGT_CUDA_ARCHS)
293+ # OUT_CUDA_ARCHS="8.0+PTX"
294+ #
271295function (cuda_archs_loose_intersection OUT_CUDA_ARCHS SRC_CUDA_ARCHS TGT_CUDA_ARCHS)
272- list (REMOVE_DUPLICATES SRC_CUDA_ARCHS)
273- set (TGT_CUDA_ARCHS_ ${TGT_CUDA_ARCHS} )
296+ set (_SRC_CUDA_ARCHS "${SRC_CUDA_ARCHS} " )
297+ set (_TGT_CUDA_ARCHS ${TGT_CUDA_ARCHS} )
298+
299+ # handle +PTX suffix: separate base arch for matching, record PTX requests
300+ set (_PTX_ARCHS)
301+ foreach (_arch ${_SRC_CUDA_ARCHS} )
302+ if (_arch MATCHES "\\ +PTX$" )
303+ string (REPLACE "+PTX" "" _base "${_arch} " )
304+ list (APPEND _PTX_ARCHS "${_base} " )
305+ list (REMOVE_ITEM _SRC_CUDA_ARCHS "${_arch} " )
306+ list (APPEND _SRC_CUDA_ARCHS "${_base} " )
307+ endif ()
308+ endforeach ()
309+ list (REMOVE_DUPLICATES _PTX_ARCHS)
310+ list (REMOVE_DUPLICATES _SRC_CUDA_ARCHS)
274311
275312 # if x.0a is in SRC_CUDA_ARCHS and x.0 is in TGT_CUDA_ARCHS then we should
276313 # remove x.0a from the source list, x.0 from the target list, and add x.0a to _CUDA_ARCHS
277314 set (_CUDA_ARCHS)
278- if ("9.0a" IN_LIST SRC_CUDA_ARCHS )
279- list (REMOVE_ITEM SRC_CUDA_ARCHS "9.0a" )
280- if ("9.0" IN_LIST TGT_CUDA_ARCHS_ )
281- list (REMOVE_ITEM TGT_CUDA_ARCHS_ "9.0" )
315+ if ("9.0a" IN_LIST _SRC_CUDA_ARCHS )
316+ list (REMOVE_ITEM _SRC_CUDA_ARCHS "9.0a" )
317+ if ("9.0" IN_LIST TGT_CUDA_ARCHS )
318+ list (REMOVE_ITEM _TGT_CUDA_ARCHS "9.0" )
282319 set (_CUDA_ARCHS "9.0a" )
283320 endif ()
284321 endif ()
285322
286- if ("10.0a" IN_LIST SRC_CUDA_ARCHS )
287- list (REMOVE_ITEM SRC_CUDA_ARCHS "10.0a" )
323+ if ("10.0a" IN_LIST _SRC_CUDA_ARCHS )
324+ list (REMOVE_ITEM _SRC_CUDA_ARCHS "10.0a" )
288325 if ("10.0" IN_LIST TGT_CUDA_ARCHS)
289- list (REMOVE_ITEM TGT_CUDA_ARCHS_ "10.0" )
326+ list (REMOVE_ITEM _TGT_CUDA_ARCHS "10.0" )
290327 set (_CUDA_ARCHS "10.0a" )
291328 endif ()
292329 endif ()
293330
294- list (SORT SRC_CUDA_ARCHS COMPARE NATURAL ORDER ASCENDING)
331+ list (SORT _SRC_CUDA_ARCHS COMPARE NATURAL ORDER ASCENDING)
295332
296333 # for each ARCH in TGT_CUDA_ARCHS find the highest arch in SRC_CUDA_ARCHS that
297334 # is less than or equal to ARCH. Non-PTX archs must share ARCH's major version (SASS
298335 # is only forward compatible within a major version); +PTX archs may match across majors.
299- foreach (_ARCH ${TGT_CUDA_ARCHS_ } )
336+ foreach (_ARCH ${_TGT_CUDA_ARCHS } )
300337 set (_TMP_ARCH)
301338 # Extract the major version of the target arch
302339 string (REGEX REPLACE "^([0-9]+)\\ ..*$" "\\ 1" TGT_ARCH_MAJOR "${_ARCH} " )
303- foreach (_SRC_ARCH ${SRC_CUDA_ARCHS } )
340+ foreach (_SRC_ARCH ${_SRC_CUDA_ARCHS } )
304341 # Extract the major version of the source arch
305342 string (REGEX REPLACE "^([0-9]+)\\ ..*$" "\\ 1" SRC_ARCH_MAJOR "${_SRC_ARCH} " )
306- # Check major- version match AND version -less-or-equal
343+ # Check version-less-or-equal, and allow PTX arches to match across majors
307344 if (_SRC_ARCH VERSION_LESS_EQUAL _ARCH)
308- if (SRC_ARCH_MAJOR STREQUAL TGT_ARCH_MAJOR)
345+ if (_SRC_ARCH IN_LIST _PTX_ARCHS OR SRC_ARCH_MAJOR STREQUAL TGT_ARCH_MAJOR)
309346 set (_TMP_ARCH "${_SRC_ARCH} " )
310347 endif ()
311348 else ()
@@ -321,6 +358,18 @@ function(cuda_archs_loose_intersection OUT_CUDA_ARCHS SRC_CUDA_ARCHS TGT_CUDA_AR
321358 endforeach ()
322359
323360 list (REMOVE_DUPLICATES _CUDA_ARCHS)
361+
362+ # reapply +PTX suffix to architectures that requested PTX
363+ set (_FINAL_ARCHS)
364+ foreach (_arch ${_CUDA_ARCHS} )
365+ if (_arch IN_LIST _PTX_ARCHS)
366+ list (APPEND _FINAL_ARCHS "${_arch} +PTX" )
367+ else ()
368+ list (APPEND _FINAL_ARCHS "${_arch} " )
369+ endif ()
370+ endforeach ()
371+ set (_CUDA_ARCHS ${_FINAL_ARCHS} )
372+
324373 set (${OUT_CUDA_ARCHS} ${_CUDA_ARCHS} PARENT_SCOPE)
325374endfunction ()
326375
0 commit comments