@@ -343,82 +343,92 @@ end
343343@device_override Base. abs (x:: Int8 ) = ccall (" extern air.abs.s.i8" , llvmcall, Int8, (Int8,), x)
344344@device_override Base. abs (x:: UInt8 ) = ccall (" extern air.abs.u.i8" , llvmcall, UInt8, (UInt8,), x)
345345
346- @device_override Base. min (x:: Int64 ) = ccall (" extern air.min.s.i64" , llvmcall, Int64, (Int64,), x)
347- @device_override Base. min (x:: UInt64 ) = ccall (" extern air.min.u.i64" , llvmcall, UInt64, (UInt64,), x)
348- @device_override Base. min (x:: Int32 ) = ccall (" extern air.min.s.i32" , llvmcall, Int32, (Int32,), x)
349- @device_override Base. min (x:: UInt32 ) = ccall (" extern air.min.u.i32" , llvmcall, UInt32, (UInt32,), x)
350- @device_override Base. min (x:: Int16 ) = ccall (" extern air.min.s.i16" , llvmcall, Int16, (Int16,), x)
351- @device_override Base. min (x:: UInt16 ) = ccall (" extern air.min.u.i16" , llvmcall, UInt16, (UInt16,), x)
352- @device_override Base. min (x:: Int8 ) = ccall (" extern air.min.s.i8" , llvmcall, Int8, (Int8,), x)
353- @device_override Base. min (x:: UInt8 ) = ccall (" extern air.min.u.i8" , llvmcall, UInt8, (UInt8,), x)
354-
355- @device_override Base. max (x:: Int64 ) = ccall (" extern air.max.s.i64" , llvmcall, Int64, (Int64,), x)
356- @device_override Base. max (x:: UInt64 ) = ccall (" extern air.max.u.i64" , llvmcall, UInt64, (UInt64,), x)
357- @device_override Base. max (x:: Int32 ) = ccall (" extern air.max.s.i32" , llvmcall, Int32, (Int32,), x)
358- @device_override Base. max (x:: UInt32 ) = ccall (" extern air.max.u.i32" , llvmcall, UInt32, (UInt32,), x)
359- @device_override Base. max (x:: Int16 ) = ccall (" extern air.max.s.i16" , llvmcall, Int16, (Int16,), x)
360- @device_override Base. max (x:: UInt16 ) = ccall (" extern air.max.u.i16" , llvmcall, UInt16, (UInt16,), x)
361- @device_override Base. max (x:: Int8 ) = ccall (" extern air.max.s.i8" , llvmcall, Int8, (Int8,), x)
362- @device_override Base. max (x:: UInt8 ) = ccall (" extern air.max.u.i8" , llvmcall, UInt8, (UInt8,), x)
363-
364- @device_function clz (x:: Int64 ) = ccall (" extern air.clz.i64" , llvmcall, Int64, (Int64,), x)
365- @device_function clz (x:: UInt64 ) = ccall (" extern air.clz.i64" , llvmcall, UInt64, (UInt64,), x)
366- @device_function clz (x:: Int32 ) = ccall (" extern air.clz.i32" , llvmcall, Int32, (Int32,), x)
367- @device_function clz (x:: UInt32 ) = ccall (" extern air.clz.i32" , llvmcall, UInt32, (UInt32,), x)
368- @device_function clz (x:: Int16 ) = ccall (" extern air.clz.i16" , llvmcall, Int16, (Int16,), x)
369- @device_function clz (x:: UInt16 ) = ccall (" extern air.clz.i16" , llvmcall, UInt16, (UInt16,), x)
370- @device_function clz (x:: Int8 ) = ccall (" extern air.clz.i8" , llvmcall, Int8, (Int8,), x)
371- @device_function clz (x:: UInt8 ) = ccall (" extern air.clz.i8" , llvmcall, UInt8, (UInt8,), x)
372-
373- @device_function ctz (x:: Int64 ) = ccall (" extern air.ctz.i64" , llvmcall, Int64, (Int64,), x)
374- @device_function ctz (x:: UInt64 ) = ccall (" extern air.ctz.i64" , llvmcall, UInt64, (UInt64,), x)
375- @device_function ctz (x:: Int32 ) = ccall (" extern air.ctz.i32" , llvmcall, Int32, (Int32,), x)
376- @device_function ctz (x:: UInt32 ) = ccall (" extern air.ctz.i32" , llvmcall, UInt32, (UInt32,), x)
377- @device_function ctz (x:: Int16 ) = ccall (" extern air.ctz.i16" , llvmcall, Int16, (Int16,), x)
378- @device_function ctz (x:: UInt16 ) = ccall (" extern air.ctz.i16" , llvmcall, UInt16, (UInt16,), x)
379- @device_function ctz (x:: Int8 ) = ccall (" extern air.ctz.i8" , llvmcall, Int8, (Int8,), x)
380- @device_function ctz (x:: UInt8 ) = ccall (" extern air.ctz.i8" , llvmcall, UInt8, (UInt8,), x)
381-
382- @device_function popcount (x:: Int64 ) = ccall (" extern air.popcount.i64" , llvmcall, Int64, (Int64,), x)
383- @device_function popcount (x:: UInt64 ) = ccall (" extern air.popcount.i64" , llvmcall, UInt64, (UInt64,), x)
384- @device_function popcount (x:: Int32 ) = ccall (" extern air.popcount.i32" , llvmcall, Int32, (Int32,), x)
385- @device_function popcount (x:: UInt32 ) = ccall (" extern air.popcount.i32" , llvmcall, UInt32, (UInt32,), x)
386- @device_function popcount (x:: Int16 ) = ccall (" extern air.popcount.i16" , llvmcall, Int16, (Int16,), x)
387- @device_function popcount (x:: UInt16 ) = ccall (" extern air.popcount.i16" , llvmcall, UInt16, (UInt16,), x)
388- @device_function popcount (x:: Int8 ) = ccall (" extern air.popcount.i8" , llvmcall, Int8, (Int8,), x)
389- @device_function popcount (x:: UInt8 ) = ccall (" extern air.popcount.i8" , llvmcall, UInt8, (UInt8,), x)
390-
391- @device_function reverse_bits (x:: Int64 ) = ccall (" extern air.reverse_bits.i64" , llvmcall, Int64, (Int64,), x)
392- @device_function reverse_bits (x:: UInt64 ) = ccall (" extern air.reverse_bits.i64" , llvmcall, UInt64, (UInt64,), x)
393- @device_function reverse_bits (x:: Int32 ) = ccall (" extern air.reverse_bits.i32" , llvmcall, Int32, (Int32,), x)
394- @device_function reverse_bits (x:: UInt32 ) = ccall (" extern air.reverse_bits.i32" , llvmcall, UInt32, (UInt32,), x)
395- @device_function reverse_bits (x:: Int16 ) = ccall (" extern air.reverse_bits.i16" , llvmcall, Int16, (Int16,), x)
396- @device_function reverse_bits (x:: UInt16 ) = ccall (" extern air.reverse_bits.i16" , llvmcall, UInt16, (UInt16,), x)
397- @device_function reverse_bits (x:: Int8 ) = ccall (" extern air.reverse_bits.i8" , llvmcall, Int8, (Int8,), x)
398- @device_function reverse_bits (x:: UInt8 ) = ccall (" extern air.reverse_bits.i8" , llvmcall, UInt8, (UInt8,), x)
399-
400-
401- function _mulhi (a:: Int64 , b:: Int64 )
402- shift = sizeof (a) * 4
403- mask = typemax (UInt32)
404- a1, a2 = (a >> shift), a & mask
405- b1, b2 = (b >> shift), b & mask
406- a1b1, a1b2, a2b1 = a1* b1, a1* b2, a2* b1
407- t1 = a1b2 + _mulhi (a2 % UInt32, b2 % UInt32)
408- t2 = a2b1 + (t1 & mask)
409- a1b1 + (t1 >> shift) + (t2 >> shift)
410- end
411- @static if isdefined (Base. MultiplicativeInverses, :_mul_high )
412- _mulhi (a:: T , b:: T ) where {T<: Union{Signed, Unsigned} } = Base. MultiplicativeInverses. _mul_high (a, b)
413- @device_override Base. MultiplicativeInverses. _mul_high (a:: Int64 , b:: Int64 ) = _mulhi (a, b)
414- else
415- _mulhi (a:: T , b:: T ) where {T<: Union{Signed, Unsigned} } = ((widen (a)* b) >>> (sizeof (a)* 8 )) % T
416- @device_override function Base. div (a:: Int64 , b:: Base.MultiplicativeInverses.SignedMultiplicativeInverse{Int64} )
417- x = _mulhi (a, b. multiplier)
418- x += (a* b. addmul) % Int64
419- ifelse (abs (b. divisor) == 1 , a* b. divisor, (signbit (x) + (x >> b. shift)) % Int64)
420- end
421- end
# Two-argument integer `min`, lowered to the AIR `air.min.{s,u}.iN` intrinsics.
# One overlay method is generated per fixed-width integer type; the intrinsic
# name encodes the signedness (`s`/`u`) and the bit width of the operands.
for T in (Int64, UInt64, Int32, UInt32, Int16, UInt16, Int8, UInt8)
    sign = T <: Signed ? "s" : "u"
    intr = "extern air.min.$(sign).i$(8 * sizeof(T))"
    @eval @device_override Base.min(x::$T, y::$T) = ccall($intr, llvmcall, $T, ($T, $T), x, y)
end
354+
# Two-argument integer `max`, lowered to the AIR `air.max.{s,u}.iN` intrinsics.
#
# XXX: `Int64` is deliberately left out of the list below. Enabling the override
#      for `max(::Int64, ::Int64)` breaks `mul!`.
#      MWE: using Revise, Metal; A, x = mtl(rand(Int32, 4, 4)), mtl(rand(Int32, 4)); A*x
#      The disabled definition is kept here for reference:
# @device_override Base.max(x::Int64, y::Int64) = ccall("extern air.max.s.i64", llvmcall, Int64, (Int64, Int64), x, y)
for T in (UInt64, Int32, UInt32, Int16, UInt16, Int8, UInt8)
    sign = T <: Signed ? "s" : "u"
    intr = "extern air.max.$(sign).i$(8 * sizeof(T))"
    @eval @device_override Base.max(x::$T, y::$T) = ccall($intr, llvmcall, $T, ($T, $T), x, y)
end
364+
# Three-argument integer `min`, lowered to the AIR `air.min3.{s,u}.iN` intrinsics.
# All three operands must share the same fixed-width integer type.
for T in (Int64, UInt64, Int32, UInt32, Int16, UInt16, Int8, UInt8)
    sign = T <: Signed ? "s" : "u"
    intr = "extern air.min3.$(sign).i$(8 * sizeof(T))"
    @eval @device_override Base.min(x::$T, y::$T, z::$T) = ccall($intr, llvmcall, $T, ($T, $T, $T), x, y, z)
end
373+
# Three-argument integer `max`, lowered to the AIR `air.max3.{s,u}.iN` intrinsics.
# All three operands must share the same fixed-width integer type.
for T in (Int64, UInt64, Int32, UInt32, Int16, UInt16, Int8, UInt8)
    sign = T <: Signed ? "s" : "u"
    intr = "extern air.max3.$(sign).i$(8 * sizeof(T))"
    @eval @device_override Base.max(x::$T, y::$T, z::$T) = ccall($intr, llvmcall, $T, ($T, $T, $T), x, y, z)
end
382+
# `leading_zeros` for the fixed-width integer types, lowered to the AIR
# `air.clz.iN` intrinsics (the same intrinsic serves signed and unsigned operands).
#
# The intrinsic takes and returns the operand type `T`. `Base.leading_zeros`
# returns an `Int`, so the count is converted with `% Int` to keep the overlay
# type-compatible with the method it replaces. The count never exceeds the bit
# width of `T`, so the conversion cannot change its value.
for T in (Int64, UInt64, Int32, UInt32, Int16, UInt16, Int8, UInt8)
    intr = "extern air.clz.i$(8 * sizeof(T))"
    @eval @device_override Base.leading_zeros(x::$T) = ccall($intr, llvmcall, $T, ($T,), x) % Int
end
# Short alias named after the underlying intrinsic.
const clz = leading_zeros
392+
# `trailing_zeros` for the fixed-width integer types, lowered to the AIR
# `air.ctz.iN` intrinsics (the same intrinsic serves signed and unsigned operands).
#
# The intrinsic takes and returns the operand type `T`. `Base.trailing_zeros`
# returns an `Int`, so the count is converted with `% Int` to keep the overlay
# type-compatible with the method it replaces. The count never exceeds the bit
# width of `T`, so the conversion cannot change its value.
for T in (Int64, UInt64, Int32, UInt32, Int16, UInt16, Int8, UInt8)
    intr = "extern air.ctz.i$(8 * sizeof(T))"
    @eval @device_override Base.trailing_zeros(x::$T) = ccall($intr, llvmcall, $T, ($T,), x) % Int
end
# Short alias named after the underlying intrinsic.
const ctz = trailing_zeros
402+
# `count_ones` for the fixed-width integer types, lowered to the AIR
# `air.popcount.iN` intrinsics (the same intrinsic serves signed and unsigned operands).
#
# The intrinsic takes and returns the operand type `T`. `Base.count_ones`
# returns an `Int`, so the count is converted with `% Int` to keep the overlay
# type-compatible with the method it replaces. The count never exceeds the bit
# width of `T`, so the conversion cannot change its value.
for T in (Int64, UInt64, Int32, UInt32, Int16, UInt16, Int8, UInt8)
    intr = "extern air.popcount.i$(8 * sizeof(T))"
    @eval @device_override Base.count_ones(x::$T) = ccall($intr, llvmcall, $T, ($T,), x) % Int
end
# Short alias named after the underlying intrinsic.
const popcount = count_ones
412+
# `bitreverse` for the fixed-width integer types, lowered to the AIR
# `air.reverse_bits.iN` intrinsics. The same intrinsic serves signed and
# unsigned operands; the result has the operand's type.
for T in (Int64, UInt64, Int32, UInt32, Int16, UInt16, Int8, UInt8)
    intr = "extern air.reverse_bits.i$(8 * sizeof(T))"
    @eval @device_override Base.bitreverse(x::$T) = ccall($intr, llvmcall, $T, ($T,), x)
end
# Short alias named after the underlying intrinsic.
const reverse_bits = bitreverse
422+
# High half of the full-width product of two integers of the same type, lowered
# to the AIR `air.mul_hi.{s,u}.iN` intrinsics and exposed under the name `mulhi`.
#
# `Base.MultiplicativeInverses._mul_high` is an internal Base function that is
# not present in every Julia version, so it is only overlaid (and aliased) when
# it exists; referencing it unconditionally would fail while loading the package
# on a Julia that lacks it.
@static if isdefined(Base.MultiplicativeInverses, :_mul_high)
    for T in (Int64, UInt64, Int32, UInt32, Int16, UInt16, Int8, UInt8)
        sign = T <: Signed ? "s" : "u"
        intr = "extern air.mul_hi.$(sign).i$(8 * sizeof(T))"
        @eval @device_override Base.MultiplicativeInverses._mul_high(x::$T, y::$T) = ccall($intr, llvmcall, $T, ($T, $T), x, y)
    end
    const mulhi = Base.MultiplicativeInverses._mul_high
else
    # No `_mul_high` to overlay: provide `mulhi` ourselves. The generic method
    # computes the high half through a widened product; the overlays below
    # replace it with the intrinsic in device code.
    mulhi(x::T, y::T) where {T<:Union{Signed, Unsigned}} = ((widen(x) * y) >>> (sizeof(x) * 8)) % T
    for T in (Int64, UInt64, Int32, UInt32, Int16, UInt16, Int8, UInt8)
        sign = T <: Signed ? "s" : "u"
        intr = "extern air.mul_hi.$(sign).i$(8 * sizeof(T))"
        @eval @device_override mulhi(x::$T, y::$T) = ccall($intr, llvmcall, $T, ($T, $T), x, y)
    end

    # Division by a precomputed signed 64-bit multiplicative inverse, routed
    # through `mulhi` so that device code takes the high half of the product
    # from the intrinsic rather than from a widened multiplication.
    @device_override function Base.div(a::Int64, b::Base.MultiplicativeInverses.SignedMultiplicativeInverse{Int64})
        x = mulhi(a, b.multiplier)
        x += (a * b.addmul) % Int64
        ifelse(abs(b.divisor) == 1, a * b.divisor, (signbit(x) + (x >> b.shift)) % Int64)
    end
end
422432
423433# From: https://forums.developer.nvidia.com/t/a-faster-and-more-accurate-implementation-of-expm1f/48085/2
424434# Original license copied below:
495505 end
496506
497507 return r
498- end
508+ end
0 commit comments