Skip to content

Commit 3720d8b

Browse files
authored
[flang][cuda] Update some bind names to the fast versions and add __sincosf (#153744)
Use the fast (`__nv_fast_*`) versions in the bind names of these fast math functions and reorder their interfaces. Add the missing __sincosf interface.
1 parent ed6d505 commit 3720d8b

File tree

3 files changed: 70 additions (+70) and 54 deletions (-54)

flang/module/cudadevice.f90

Lines changed: 53 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -394,27 +394,77 @@ attributes(device) subroutine sincospi(x, y, z) bind(c,name='__nv_sincospi')
394394
end interface
395395

396396
interface
397-
attributes(device) real(4) function __cosf(x) bind(c, name='__nv_cosf')
397+
attributes(device) real(4) function __cosf(x) bind(c, name='__nv_fast_cosf')
398398
real(4), value :: x
399399
end function
400400
end interface
401401

402+
interface __exp10f
403+
attributes(device) real function __exp10f(r) bind(c, name='__nv_fast_exp10f')
404+
!dir$ ignore_tkr (d) r
405+
real, value :: r
406+
end function
407+
end interface
408+
409+
interface __expf
410+
attributes(device) real function __expf(r) bind(c, name='__nv_fast_expf')
411+
!dir$ ignore_tkr (d) r
412+
real, value :: r
413+
end function
414+
end interface
415+
402416
interface __fdividef
403417
attributes(device) real function __fdividef(r,d) bind(c, name='__nv_fast_fdividef')
404418
!dir$ ignore_tkr (d) r, (d) d
405419
real, value :: r,d
406420
end function
407421
end interface
408422

423+
interface __log10f
424+
attributes(device) real function __log10f(r) bind(c, name='__nv_fast_log10f')
425+
!dir$ ignore_tkr (d) r
426+
real, value :: r
427+
end function
428+
end interface
429+
430+
interface __log2f
431+
attributes(device) real function __log2f(r) bind(c, name='__nv_fast_log2f')
432+
!dir$ ignore_tkr (d) r
433+
real, value :: r
434+
end function
435+
end interface
436+
437+
interface __logf
438+
attributes(device) real function __logf(r) bind(c, name='__nv_fast_logf')
439+
!dir$ ignore_tkr (d) r
440+
real, value :: r
441+
end function
442+
end interface
443+
444+
interface
445+
attributes(device) real(4) function __powf(x,y) bind(c, name='__nv_fast_powf')
446+
!dir$ ignore_tkr (d) x, y
447+
real(4), value :: x, y
448+
end function
449+
end interface
450+
451+
interface __sincosf
452+
attributes(device) subroutine __sincosf(r, s, c) bind(c, name='__nv_fast_sincosf')
453+
!dir$ ignore_tkr (d) r, (d) s, (d) c
454+
real, value :: r
455+
real :: s, c
456+
end subroutine
457+
end interface
458+
409459
interface __sinf
410-
attributes(device) real function __sinf(r) bind(c, name='__nv_sinf')
460+
attributes(device) real function __sinf(r) bind(c, name='__nv_fast_sinf')
411461
!dir$ ignore_tkr (d) r
412462
real, value :: r
413463
end function
414464
end interface
415465

416466
interface __tanf
417-
attributes(device) real function __tanf(r) bind(c, name='__nv_tanf')
467+
attributes(device) real function __tanf(r) bind(c, name='__nv_fast_tanf')
418468
!dir$ ignore_tkr (d) r
419469
real, value :: r
420470
end function
@@ -1078,13 +1128,6 @@ attributes(device) real(8) function sinpi(x) bind(c,name='__nv_sinpi')
10781128
end function
10791129
end interface
10801130

1081-
interface
1082-
attributes(device) real(4) function __powf(x,y) bind(c, name='__nv_powf')
1083-
!dir$ ignore_tkr (d) x, y
1084-
real(4), value :: x, y
1085-
end function
1086-
end interface
1087-
10881131
interface __brev
10891132
attributes(device) integer function __brev(i) bind(c, name='__nv_brev')
10901133
!dir$ ignore_tkr (d) i
@@ -1944,41 +1987,6 @@ attributes(device,host) logical function on_device() bind(c)
19441987
end function
19451988
end interface
19461989

1947-
interface __log2f
1948-
attributes(device) real function __log2f(r) bind(c, name='__nv_log2f')
1949-
!dir$ ignore_tkr (d) r
1950-
real, value :: r
1951-
end function
1952-
end interface
1953-
1954-
interface __log10f
1955-
attributes(device) real function __log10f(r) bind(c, name='__nv_log10f')
1956-
!dir$ ignore_tkr (d) r
1957-
real, value :: r
1958-
end function
1959-
end interface
1960-
1961-
interface __logf
1962-
attributes(device) real function __logf(r) bind(c, name='__nv_logf')
1963-
!dir$ ignore_tkr (d) r
1964-
real, value :: r
1965-
end function
1966-
end interface
1967-
1968-
interface __expf
1969-
attributes(device) real function __expf(r) bind(c, name='__nv_expf')
1970-
!dir$ ignore_tkr (d) r
1971-
real, value :: r
1972-
end function
1973-
end interface
1974-
1975-
interface __exp10f
1976-
attributes(device) real function __exp10f(r) bind(c, name='__nv_exp10f')
1977-
!dir$ ignore_tkr (d) r
1978-
real, value :: r
1979-
end function
1980-
end interface
1981-
19821990
contains
19831991

19841992
attributes(device) subroutine syncthreads()

flang/test/Lower/CUDA/cuda-device-proc.cuf

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -140,7 +140,7 @@ end
140140
! CHECK: %{{.*}} = fir.call @__nv_brevll(%{{.*}}) proc_attrs<bind_c> fastmath<contract> : (i64) -> i64
141141
! CHECK: %{{.*}} = fir.call @__nv_clz(%{{.*}}) proc_attrs<bind_c> fastmath<contract> : (i32) -> i32
142142
! CHECK: %{{.*}} = fir.call @__nv_clzll(%{{.*}}) proc_attrs<bind_c> fastmath<contract> : (i64) -> i32
143-
! CHECK: %{{.*}} = fir.call @__nv_cosf(%{{.*}}) proc_attrs<bind_c> fastmath<contract> : (f32) -> f32
143+
! CHECK: %{{.*}} = fir.call @__nv_fast_cosf(%{{.*}}) proc_attrs<bind_c> fastmath<contract> : (f32) -> f32
144144
! CHECK: %{{.*}} = fir.call @__nv_ddiv_rn(%{{.*}}, %{{.*}}) proc_attrs<bind_c> fastmath<contract> : (f64, f64) -> f64
145145
! CHECK: %{{.*}} = fir.call @__nv_ddiv_rz(%{{.*}}, %{{.*}}) proc_attrs<bind_c> fastmath<contract> : (f64, f64) -> f64
146146
! CHECK: %{{.*}} = fir.call @__nv_ddiv_ru(%{{.*}}, %{{.*}}) proc_attrs<bind_c> fastmath<contract> : (f64, f64) -> f64
@@ -159,7 +159,7 @@ end
159159
! CHECK: %{{.*}} = fir.call @__nv_double2uint_rz(%{{.*}}) proc_attrs<bind_c> fastmath<contract> : (f64) -> i32
160160
! CHECK: %{{.*}} = fir.call @__nv_mul24(%{{.*}}, %{{.*}}) proc_attrs<bind_c> fastmath<contract> : (i32, i32) -> i32
161161
! CHECK: %{{.*}} = fir.call @__nv_umul24(%{{.*}}, %{{.*}}) proc_attrs<bind_c> fastmath<contract> : (i32, i32) -> i32
162-
! CHECK: %{{.*}} = fir.call @__nv_powf(%{{.*}}, %{{.*}}) proc_attrs<bind_c> fastmath<contract> : (f32, f32) -> f32
162+
! CHECK: %{{.*}} = fir.call @__nv_fast_powf(%{{.*}}, %{{.*}}) proc_attrs<bind_c> fastmath<contract> : (f32, f32) -> f32
163163
! CHECK: %{{.*}} = fir.call @__nv_ull2double_rd(%{{.*}}) proc_attrs<bind_c> fastmath<contract> : (i64) -> f64
164164
! CHECK: %{{.*}} = fir.call @__nv_ull2double_rn(%{{.*}}) proc_attrs<bind_c> fastmath<contract> : (i64) -> f64
165165
! CHECK: %{{.*}} = fir.call @__nv_ull2double_ru(%{{.*}}) proc_attrs<bind_c> fastmath<contract> : (i64) -> f64

flang/test/Lower/CUDA/cuda-libdevice.cuf

Lines changed: 15 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -83,9 +83,17 @@ attributes(global) subroutine test_log()
8383
end subroutine
8484

8585
! CHECK-LABEL: _QPtest_log
86-
! CHECK: %{{.*}} = fir.call @__nv_logf(%{{.*}}) proc_attrs<bind_c> fastmath<contract> : (f32) -> f32
87-
! CHECK: %{{.*}} = fir.call @__nv_log2f(%{{.*}}) proc_attrs<bind_c> fastmath<contract> : (f32) -> f32
88-
! CHECK: %{{.*}} = fir.call @__nv_log10f(%{{.*}}) proc_attrs<bind_c> fastmath<contract> : (f32) -> f32
86+
! CHECK: %{{.*}} = fir.call @__nv_fast_logf(%{{.*}}) proc_attrs<bind_c> fastmath<contract> : (f32) -> f32
87+
! CHECK: %{{.*}} = fir.call @__nv_fast_log2f(%{{.*}}) proc_attrs<bind_c> fastmath<contract> : (f32) -> f32
88+
! CHECK: %{{.*}} = fir.call @__nv_fast_log10f(%{{.*}}) proc_attrs<bind_c> fastmath<contract> : (f32) -> f32
89+
90+
attributes(global) subroutine test_sincosf()
91+
real :: r, s, c
92+
call __sincosf(r, s, c)
93+
end subroutine
94+
95+
! CHECK-LABEL: _QPtest_sincosf
96+
! CHECK: fir.call @__nv_fast_sincosf(%{{.*}}, %{{.*}}#0, %{{.*}}#0) proc_attrs<bind_c> fastmath<contract> : (f32, !fir.ref<f32>, !fir.ref<f32>) -> ()
8997

9098
attributes(global) subroutine test_sinf()
9199
real :: res
@@ -94,7 +102,7 @@ attributes(global) subroutine test_sinf()
94102
end subroutine
95103

96104
! CHECK-LABEL: _QPtest_sinf
97-
! CHECK: %{{.*}} = fir.call @__nv_sinf(%{{.*}}) proc_attrs<bind_c> fastmath<contract> : (f32) -> f32
105+
! CHECK: %{{.*}} = fir.call @__nv_fast_sinf(%{{.*}}) proc_attrs<bind_c> fastmath<contract> : (f32) -> f32
98106

99107
attributes(global) subroutine test_tanf()
100108
real :: res
@@ -103,7 +111,7 @@ attributes(global) subroutine test_tanf()
103111
end subroutine
104112

105113
! CHECK-LABEL: _QPtest_tanf
106-
! CHECK: %{{.*}} = fir.call @__nv_tanf(%{{.*}}) proc_attrs<bind_c> fastmath<contract> : (f32) -> f32
114+
! CHECK: %{{.*}} = fir.call @__nv_fast_tanf(%{{.*}}) proc_attrs<bind_c> fastmath<contract> : (f32) -> f32
107115

108116
attributes(global) subroutine test_exp()
109117
real :: res
@@ -113,8 +121,8 @@ attributes(global) subroutine test_exp()
113121
end subroutine
114122

115123
! CHECK-LABEL: _QPtest_exp
116-
! CHECK: %{{.*}} = fir.call @__nv_expf(%{{.*}}) proc_attrs<bind_c> fastmath<contract> : (f32) -> f32
117-
! CHECK: %{{.*}} = fir.call @__nv_exp10f(%{{.*}}) proc_attrs<bind_c> fastmath<contract> : (f32) -> f32
124+
! CHECK: %{{.*}} = fir.call @__nv_fast_expf(%{{.*}}) proc_attrs<bind_c> fastmath<contract> : (f32) -> f32
125+
! CHECK: %{{.*}} = fir.call @__nv_fast_exp10f(%{{.*}}) proc_attrs<bind_c> fastmath<contract> : (f32) -> f32
118126

119127
attributes(global) subroutine test_double2ll_rX()
120128
integer(8) :: res

0 commit comments

Comments
 (0)