Skip to content

Commit ae0540b

Browse files
authored
ZLUDA v3.9.1 (#84)
* Handle ptx kernel parameter attribute .ptr. * Add cublasStrmm_v2, cublasSgeam. * Handle MIOpen internal error. * Implement cuStreamGetPriority. * Add cudnnBatchNormalizationForwardInference. * Fix nvrtc.
1 parent 4d14bf9 commit ae0540b

File tree

10 files changed

+291
-200
lines changed

10 files changed

+291
-200
lines changed

ptx/src/ast.rs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -144,6 +144,10 @@ pub struct Function<'a, ID, S> {
144144
pub body: Option<Vec<S>>,
145145
}
146146

147+
pub enum KernelParameterAttribute {
148+
Pointer(StateSpace),
149+
}
150+
147151
#[derive(PartialEq, Eq, Clone, Hash)]
148152
pub enum Type {
149153
// .param.b32 foo;

ptx/src/ptx.lalrpop

Lines changed: 53 additions & 66 deletions
Original file line numberDiff line numberDiff line change
@@ -124,6 +124,7 @@ match {
124124
".popc",
125125
".pragma",
126126
".pred",
127+
".ptr",
127128
".r",
128129
".red",
129130
".reg",
@@ -566,6 +567,13 @@ KernelArguments: Vec<ast::VariableDeclaration<&'input str>> = {
566567
"(" <args:Comma<VariableDeclarationEntry>> ")" => args
567568
};
568569

570+
// https://docs.nvidia.com/cuda/parallel-thread-execution/#kernel-function-parameter-attributes
571+
KernelParameterAttribute: ast::KernelParameterAttribute = {
572+
".ptr" <state_space:StateSpaceSpecifier?> => {
573+
ast::KernelParameterAttribute::Pointer(state_space.unwrap_or(ast::StateSpace::Generic))
574+
}
575+
};
576+
569577
FnArguments: Vec<ast::VariableDeclaration<&'input str>> = {
570578
"(" <args:Comma<VariableDeclarationFunc>> ")" => args
571579
};
@@ -662,21 +670,21 @@ InitializerNoAdd: ast::Initializer<&'input str> = {
662670
}
663671

664672
VariableDeclarationFunc: ast::VariableDeclaration<&'input str> = {
665-
<var:VariableDeclarationBase> =>
673+
<var:VariableDeclaration> =>
666674
validate_variable_declaration_func(var, errors)
667675
}
668676

669677
VariableDeclarationEntry: ast::VariableDeclaration<&'input str> = {
670-
<var:VariableDeclarationBase> =>
678+
<var:VariableDeclaration> =>
671679
validate_variable_declaration_entry(var, errors)
672680
}
673681

674682
VariableDeclarationProto: ast::VariableDeclaration<&'input str> = {
675-
<var:VariableDeclarationBase> =>
683+
<var:VariableDeclaration> =>
676684
validate_variable_declaration_proto(var, errors)
677685
}
678686

679-
VariableDeclarationBase: ast::VariableDeclaration<&'input str> = {
687+
VariableDeclaration: ast::VariableDeclaration<&'input str> = {
680688
<variable:VariableDeclarationBegin> <name:ExtendedIDOrBlank> <dims:ArrayDimensions?> => {
681689
let mut variable = variable.clone();
682690
variable.name = name;
@@ -710,16 +718,54 @@ MultiVariableDefinition: Vec<ast::MultiVariableDefinition<&'input str>> = {
710718
}
711719

712720
VariableDeclarationBegin: ast::VariableDeclaration<&'input str> = {
713-
<state_space:StateSpaceSpecifier> <align:Align?> <type_:AnyType> => {
714-
ast::VariableDeclaration {
721+
<state_space:StateSpaceSpecifier> <align:Align> <v_len:VectorPrefix?> <type_:ScalarType> => {
722+
let mut variable = ast::VariableDeclaration {
723+
align: Some(align),
724+
type_: ast::Type::Scalar(type_),
725+
state_space,
726+
name: ""
727+
};
728+
if let Some(v_len) = v_len {
729+
variable.type_ = ast::Type::Vector(type_, v_len)
730+
}
731+
variable
732+
},
733+
<state_space:StateSpaceSpecifier> <v_len:VectorPrefix?> <type_:ScalarType> <attr:KernelParameterAttribute?> <align:Align?> => {
734+
let mut variable = ast::VariableDeclaration {
715735
align,
716-
type_,
736+
type_: ast::Type::Scalar(type_),
737+
state_space,
738+
name: ""
739+
};
740+
if let Some(v_len) = v_len {
741+
variable.type_ = ast::Type::Vector(type_, v_len)
742+
}
743+
if let Some(attr) = attr {
744+
variable.type_ = match attr {
745+
ast::KernelParameterAttribute::Pointer(state_space) => {
746+
ast::Type::Pointer(type_, state_space)
747+
}
748+
};
749+
}
750+
variable
751+
},
752+
<state_space:StateSpaceSpecifier> <var:VariableDeclarationBeginNonScalarTypeWithAlign> => {
753+
ast::VariableDeclaration {
754+
align: var.1,
755+
type_: var.0,
717756
state_space,
718757
name: ""
719758
}
720759
}
721760
}
722761

762+
VariableDeclarationBeginNonScalarTypeWithAlign: (ast::Type, Option<u32>) = {
763+
<align:Align> ".texref" => (ast::Type::Texref, Some(align)),
764+
".texref" <align:Align?> => (ast::Type::Texref, align),
765+
<align:Align> ".surfref" => (ast::Type::Surfref, Some(align)),
766+
".surfref" <align:Align?> => (ast::Type::Surfref, align)
767+
}
768+
723769
VariableDefinitionOnce: (&'input str, Option<Either<u32, Vec<u32>>>, Option<ast::Initializer<&'input str>>) = {
724770
<name:ExtendedIDOrBlank> <suffix:VariableDefinitionSuffix?> <init:VariableDefinitionInitializer?> => (name, suffix, init)
725771
}
@@ -734,13 +780,6 @@ VariableDefinitionInitializer: ast::Initializer<&'input str> = {
734780
"=" <init:Initializer> => init
735781
}
736782

737-
AnyType: ast::Type = {
738-
".texref" => ast::Type::Texref,
739-
".surfref" => ast::Type::Surfref,
740-
<v_len:VectorPrefix> <type_:ScalarType> => ast::Type::Vector(type_, v_len),
741-
<type_:ScalarType> => ast::Type::Scalar(type_),
742-
}
743-
744783
#[inline]
745784
SizedScalarType: ast::ScalarType = {
746785
".b8" => ast::ScalarType::B8,
@@ -2780,58 +2819,6 @@ AnyBitType: ast::ScalarType = {
27802819
".b64" => ast::ScalarType::B64,
27812820
};
27822821

2783-
VariableScalarUnitialized<T>: (Option<u32>, T, &'input str) = {
2784-
<align:Align?> <type_:T> <name:ExtendedID> => {
2785-
(align, type_, name)
2786-
}
2787-
}
2788-
2789-
VariableScalar<T>: (Option<u32>, T, &'input str, Vec<u8>) = {
2790-
<align:Align?> <type_:T> <name:ExtendedID> <init:VariableInitalizer?> => {
2791-
let initializer = init.map(ast::ImmediateValue::to_bytes).unwrap_or(Vec::new());
2792-
(align, type_, name, initializer)
2793-
}
2794-
}
2795-
2796-
VariableInitalizer: ast::ImmediateValue = {
2797-
"=" <v:ImmediateValue> => v
2798-
}
2799-
2800-
VariableVector<T>: (Option<u32>, u8, T, &'input str) = {
2801-
<align:Align?> <v_len:VectorPrefix> <type_:T> <name:ExtendedID> => {
2802-
(align, v_len, type_, name)
2803-
}
2804-
}
2805-
2806-
// empty dimensions [0] means it's a pointer
2807-
VariableArrayOrPointer<T>: (Option<u32>, T, &'input str, ast::ArrayOrPointer) = {
2808-
<align:Align?> <typ:SizedScalarType> <name:ExtendedID> <dims:ArrayDimensions> <init:ArrayInitializer?> => {
2809-
let mut dims = dims;
2810-
let array_init = match init {
2811-
Some(init) => {
2812-
let init_vec = match init.to_vec(typ, &mut dims) {
2813-
Err(error) => {
2814-
errors.push(ParseError::User { error });
2815-
Vec::new()
2816-
}
2817-
Ok(x) => x
2818-
};
2819-
ast::ArrayOrPointer::Array { dimensions: dims, init: init_vec }
2820-
}
2821-
None => {
2822-
if dims.len() > 1 && dims.contains(&0) {
2823-
errors.push(ParseError::User { error: ast::PtxError::ZeroDimensionArray });
2824-
}
2825-
match &*dims {
2826-
[0] => ast::ArrayOrPointer::Pointer,
2827-
_ => ast::ArrayOrPointer::Array { dimensions: dims, init: Vec::new() }
2828-
}
2829-
}
2830-
};
2831-
(align, typ, name, array_init)
2832-
}
2833-
}
2834-
28352822
// [0] and [] are treated the same
28362823
ArrayDimensions: Vec<u32> = {
28372824
ArrayEmptyDimension => vec![0u32],

zluda/src/cuda.rs

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -144,6 +144,7 @@ cuda_function_declarations!(
144144
cuStreamGetCtx,
145145
cuStreamGetCtx_ptsz,
146146
cuStreamGetFlags,
147+
cuStreamGetPriority,
147148
cuStreamIsCapturing,
148149
cuStreamQuery,
149150
cuStreamSynchronize,
@@ -1087,6 +1088,13 @@ mod definitions {
10871088
stream::get_flags(hStream, flags)
10881089
}
10891090

1091+
pub(crate) unsafe fn cuStreamGetPriority(
1092+
hStream: *mut stream::Stream,
1093+
priority: *mut ::std::os::raw::c_int,
1094+
) -> Result<(), CUresult> {
1095+
stream::get_priority(hStream, priority)
1096+
}
1097+
10901098
pub(crate) unsafe fn cuStreamIsCapturing(
10911099
hStream: *mut stream::Stream,
10921100
captureStatus: *mut hipStreamCaptureStatus,
@@ -1662,9 +1670,7 @@ mod definitions {
16621670
array::mipmapped_create(pHandle, pMipmappedArrayDesc, numMipmapLevels)
16631671
}
16641672

1665-
pub(crate) unsafe fn cuMipmappedArrayDestroy(
1666-
hMipmappedArray: CUmipmappedArray,
1667-
) -> hipError_t {
1673+
pub(crate) unsafe fn cuMipmappedArrayDestroy(hMipmappedArray: CUmipmappedArray) -> hipError_t {
16681674
array::mipmapped_destroy(hMipmappedArray)
16691675
}
16701676

zluda/src/impl/stream.rs

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -181,6 +181,12 @@ pub(crate) unsafe fn get_flags(stream: *mut Stream, flags: *mut u32) -> Result<(
181181
Ok(())
182182
}
183183

184+
pub(crate) unsafe fn get_priority(stream: *mut Stream, priority: *mut i32) -> Result<(), CUresult> {
185+
let hip_stream = as_hip_stream(stream)?;
186+
hip_call_cuda! { hipStreamGetPriority(hip_stream, priority) };
187+
Ok(())
188+
}
189+
184190
pub(crate) unsafe fn is_capturing(
185191
stream: *mut Stream,
186192
capture_status: *mut hipStreamCaptureStatus,

zluda_blas/src/cublas.rs

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3828,7 +3828,7 @@ pub extern "system" fn cublasZtrsm_v2(
38283828
}
38293829

38303830
#[no_mangle]
3831-
pub extern "system" fn cublasStrmm_v2(
3831+
pub unsafe extern "system" fn cublasStrmm_v2(
38323832
handle: cublasHandle_t,
38333833
side: cublasSideMode_t,
38343834
uplo: cublasFillMode_t,
@@ -3844,7 +3844,9 @@ pub extern "system" fn cublasStrmm_v2(
38443844
C: *mut f32,
38453845
ldc: ::std::os::raw::c_int,
38463846
) -> cublasStatus_t {
3847-
crate::unsupported()
3847+
crate::strmm(
3848+
handle, side, uplo, trans, diag, m, n, alpha, A, lda, B, ldb, C, ldc,
3849+
)
38483850
}
38493851

38503852
#[no_mangle]
@@ -4292,7 +4294,7 @@ pub extern "system" fn cublasHgemmStridedBatched(
42924294
}
42934295

42944296
#[no_mangle]
4295-
pub extern "system" fn cublasSgeam(
4297+
pub unsafe extern "system" fn cublasSgeam(
42964298
handle: cublasHandle_t,
42974299
transa: cublasOperation_t,
42984300
transb: cublasOperation_t,
@@ -4307,7 +4309,9 @@ pub extern "system" fn cublasSgeam(
43074309
C: *mut f32,
43084310
ldc: ::std::os::raw::c_int,
43094311
) -> cublasStatus_t {
4310-
crate::unsupported()
4312+
crate::sgeam(
4313+
handle, transa, transb, m, n, alpha, A, lda, beta, B, ldb, C, ldc,
4314+
)
43114315
}
43124316

43134317
#[no_mangle]

zluda_blas/src/lib.rs

Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -489,6 +489,40 @@ fn to_compute_type(compute_type: cublasComputeType_t) -> rocblas_datatype {
489489
}
490490
}
491491

492+
unsafe fn sgeam(
493+
handle: *mut cublasContext,
494+
transa: cublasOperation_t,
495+
transb: cublasOperation_t,
496+
m: i32,
497+
n: i32,
498+
alpha: *const f32,
499+
a: *const f32,
500+
lda: i32,
501+
beta: *const f32,
502+
b: *const f32,
503+
ldb: i32,
504+
c: *mut f32,
505+
ldc: i32,
506+
) -> cublasStatus_t {
507+
let transa = op_from_cuda(transa);
508+
let transb = op_from_cuda(transb);
509+
to_cuda(rocblas_sgeam(
510+
handle.cast(),
511+
transa,
512+
transb,
513+
m,
514+
n,
515+
alpha,
516+
a,
517+
lda,
518+
beta,
519+
b,
520+
ldb,
521+
c,
522+
ldc,
523+
))
524+
}
525+
492526
unsafe fn zgemm_strided_batch(
493527
handle: *mut cublasContext,
494528
transa: cublasOperation_t,
@@ -1146,6 +1180,44 @@ unsafe fn dtrsm(
11461180
))
11471181
}
11481182

1183+
unsafe fn strmm(
1184+
handle: *mut cublasContext,
1185+
side: cublasSideMode_t,
1186+
uplo: cublasFillMode_t,
1187+
trans: cublasOperation_t,
1188+
diag: cublasDiagType_t,
1189+
m: i32,
1190+
n: i32,
1191+
alpha: *const f32,
1192+
a: *const f32,
1193+
lda: i32,
1194+
b: *const f32,
1195+
ldb: i32,
1196+
c: *mut f32,
1197+
ldc: i32,
1198+
) -> cublasStatus_t {
1199+
let side = to_side(side);
1200+
let uplo = to_fill(uplo);
1201+
let trans = op_from_cuda(trans);
1202+
let diag = to_diag(diag);
1203+
to_cuda(rocblas_strmm(
1204+
handle.cast(),
1205+
side,
1206+
uplo,
1207+
trans,
1208+
diag,
1209+
m,
1210+
n,
1211+
alpha,
1212+
a,
1213+
lda,
1214+
b,
1215+
ldb,
1216+
c,
1217+
ldc,
1218+
))
1219+
}
1220+
11491221
unsafe fn gemm_batched_ex(
11501222
handle: cublasHandle_t,
11511223
transa: cublasOperation_t,

zluda_dnn/src/cudnn.rs

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2729,7 +2729,7 @@ impl cudnnBatchNormOps_t {
27292729
pub struct cudnnBatchNormOps_t(pub ::std::os::raw::c_int);
27302730

27312731
#[no_mangle]
2732-
pub extern "system" fn cudnnBatchNormalizationForwardInference(
2732+
pub unsafe extern "system" fn cudnnBatchNormalizationForwardInference(
27332733
handle: cudnnHandle_t,
27342734
mode: cudnnBatchNormMode_t,
27352735
alpha: *const ::std::os::raw::c_void,
@@ -2745,7 +2745,22 @@ pub extern "system" fn cudnnBatchNormalizationForwardInference(
27452745
estimatedVariance: *const ::std::os::raw::c_void,
27462746
epsilon: f64,
27472747
) -> cudnnStatus_t {
2748-
crate::unsupported()
2748+
crate::batch_normalization_forward_inference(
2749+
handle,
2750+
mode,
2751+
alpha,
2752+
beta,
2753+
xDesc,
2754+
x,
2755+
yDesc,
2756+
y,
2757+
bnScaleBiasMeanVarDesc,
2758+
bnScale,
2759+
bnBias,
2760+
estimatedMean,
2761+
estimatedVariance,
2762+
epsilon,
2763+
)
27492764
}
27502765
impl cudnnNormMode_t {
27512766
pub const CUDNN_NORM_PER_ACTIVATION: cudnnNormMode_t = cudnnNormMode_t(0);

0 commit comments

Comments
 (0)