|
1 | 1 | use crate::{
|
| 2 | + convolution_descriptor::ConvolutionDescriptor, |
| 3 | + convolution_fwd_algo::{BestHeuristic, ConvolutionFwdAlgo, SupportedConvFwd}, |
2 | 4 | data_type::*,
|
3 | 5 | error::{CudnnError, IntoResult},
|
| 6 | + filter_descriptor::FilterDescriptor, |
4 | 7 | nan_propagation::*,
|
5 | 8 | op_tensor_descriptor::*,
|
6 | 9 | sys,
|
@@ -549,6 +552,249 @@ impl CudnnContext {
|
549 | 552 | .into_result()
|
550 | 553 | }
|
551 | 554 | }
|
| 555 | + |
| 556 | + /// This function serves as a heuristic for obtaining the best suited algorithm for |
| 557 | + /// `cudnnConvolutionForward()` for the given layer specifications. |
| 558 | + /// |
| 559 | + /// It will return the best algorithm according to an internal heuristic. |
| 560 | + /// |
| 561 | + /// # Arguments |
| 562 | + /// |
| 563 | + /// * `x_desc` - previously initialized tensor descriptor for the input map. |
| 564 | + /// |
| 565 | + /// * `w_desc` - previously initialized tensor descriptor for the filter map. |
| 566 | + /// |
| 567 | + /// * `y_desc` - previously initialized tensor descriptor for the output map. |
| 568 | + /// |
| 569 | + /// * `conv_desc` - previously initialized convolution descriptor. |
| 570 | + /// |
| 571 | + /// **Do note** that the best found algorithm `MathType` and the one supplied to the convolution |
| 572 | + /// descriptor's at its creation may differ, for this reason you should always manually set the |
| 573 | + /// math type of the convolution descriptor according to the one of the returned algorithm, as |
| 574 | + /// pictured in the following example. |
| 575 | + /// |
| 576 | + /// # Examples |
| 577 | + /// |
| 578 | + /// ``` |
| 579 | + /// # use std::error::Error; |
| 580 | + /// # |
| 581 | + /// # fn main() -> Result<(), Box<dyn Error>> { |
| 582 | + /// use cudnn::{ |
| 583 | + /// ConvolutionDescriptor, ConvolutionMode, CudnnContext, FilterDescriptor, MathType, |
| 584 | + /// TensorDescriptor, NCHW, |
| 585 | + /// }; |
| 586 | + /// |
| 587 | + /// let ctx = CudnnContext::new()?; |
| 588 | + /// |
| 589 | + /// let padding = [0, 0]; |
| 590 | + /// let stride = [1, 1]; |
| 591 | + /// let dilation = [1, 1]; |
| 592 | + /// let groups = 1; |
| 593 | + /// let mode = ConvolutionMode::CrossCorrelation; |
| 594 | + /// let math_type = MathType::Default; |
| 595 | + /// |
| 596 | + /// // 2-dimensional convolution. |
| 597 | + /// let mut conv_desc = ConvolutionDescriptor::<f32, 2>::new(padding, stride, dilation, groups, mode, math_type)?; |
| 598 | + /// |
| 599 | + /// let input_desc = TensorDescriptor::<f32, _, 4>::new([3, 2, 5, 5,], NCHW)?; |
| 600 | + /// let filter_desc = FilterDescriptor::<f32, _, 4>::new([3, 2, 2, 2], NCHW)?; |
| 601 | + /// let output_desc = TensorDescriptor::<f32, _, 4>::new([3, 3, 4, 4], NCHW)?; |
| 602 | + /// |
| 603 | + /// let algo = ctx.get_convolution_forward_algorithm(&input_desc, &filter_desc, &output_desc, &conv_desc)?; |
| 604 | + /// |
| 605 | + /// conv_desc.set_math_type(algo.math_type())?; |
| 606 | + /// # Ok(()) |
| 607 | + /// # } |
| 608 | + /// ``` |
| 609 | + pub fn get_convolution_forward_algorithm< |
| 610 | + InType, |
| 611 | + InFmt, |
| 612 | + FilterType, |
| 613 | + FilterFmt, |
| 614 | + CompType, |
| 615 | + OutType, |
| 616 | + OutFmt, |
| 617 | + const D: usize, |
| 618 | + const N: usize, |
| 619 | + >( |
| 620 | + &self, |
| 621 | + x_desc: &TensorDescriptor<InType, InFmt, D>, |
| 622 | + w_desc: &FilterDescriptor<FilterType, FilterFmt, D>, |
| 623 | + y_desc: &TensorDescriptor<OutType, OutFmt, D>, |
| 624 | + conv_desc: &ConvolutionDescriptor<CompType, N>, |
| 625 | + ) -> Result<BestHeuristic, CudnnError> |
| 626 | + where |
| 627 | + InType: DataType, |
| 628 | + InFmt: TensorFormat + SupportedType<InType>, |
| 629 | + FilterType: DataType, |
| 630 | + FilterFmt: TensorFormat + SupportedType<FilterType>, |
| 631 | + CompType: DataType, |
| 632 | + OutType: DataType, |
| 633 | + OutFmt: TensorFormat + SupportedType<OutType>, |
| 634 | + BestHeuristic: |
| 635 | + SupportedConvFwd<InType, InFmt, FilterType, FilterFmt, CompType, OutType, OutFmt, D, N>, |
| 636 | + { |
| 637 | + let mut returned_algo_count = MaybeUninit::uninit(); |
| 638 | + let mut perf_results = MaybeUninit::uninit(); |
| 639 | + |
| 640 | + unsafe { |
| 641 | + sys::cudnnGetConvolutionForwardAlgorithm_v7( |
| 642 | + self.raw, |
| 643 | + x_desc.raw, |
| 644 | + w_desc.raw, |
| 645 | + conv_desc.raw, |
| 646 | + y_desc.raw, |
| 647 | + 1, |
| 648 | + returned_algo_count.as_mut_ptr(), |
| 649 | + perf_results.as_mut_ptr(), |
| 650 | + ) |
| 651 | + .into_result()?; |
| 652 | + |
| 653 | + let returned_algo_count = returned_algo_count.assume_init(); |
| 654 | + |
| 655 | + match returned_algo_count { |
| 656 | + // This is general enough so that in the future it can be expanded to be more |
| 657 | + // complex. |
| 658 | + 1 => { |
| 659 | + let results: Vec<BestHeuristic> = { |
| 660 | + let raw_results = std::slice::from_raw_parts( |
| 661 | + perf_results.as_ptr(), |
| 662 | + returned_algo_count as usize, |
| 663 | + ); |
| 664 | + |
| 665 | + raw_results |
| 666 | + .iter() |
| 667 | + .copied() |
| 668 | + .map(BestHeuristic::try_from) |
| 669 | + .filter_map(Result::ok) |
| 670 | + .collect() |
| 671 | + }; |
| 672 | + |
| 673 | + let algo = results[0]; |
| 674 | + |
| 675 | + Ok(algo) |
| 676 | + } |
| 677 | + _ => return Err(CudnnError::BadParam), |
| 678 | + } |
| 679 | + } |
| 680 | + } |
| 681 | + |
| 682 | + /// This function returns the amount of GPU memory workspace the user needs to allocate to be |
| 683 | + /// able to call `cudnnConvolutionForward()` with the specified algorithm. The workspace |
| 684 | + /// allocated will then be passed to the routine `cudnnConvolutionForward()`. |
| 685 | + /// |
| 686 | + /// The specified algorithm can be the result of the call to |
| 687 | + /// [`get_convolution_forward_algorithm`](crate::CudnnContext::get_convolution_forward_algorithm) |
| 688 | + /// or can be chosen arbitrarily by the user. In the latter case workspace size can be directly |
| 689 | + /// obtained by calling [`workspace_size`](crate::BestHeuristic::workspace_size) on the returned |
| 690 | + /// algorithm. |
| 691 | + /// |
| 692 | + /// **Do note** that not every algorithm is available for every configuration of the input |
| 693 | + /// tensor and/or every configuration of the convolution descriptor. |
| 694 | + /// |
| 695 | + /// # Arguments |
| 696 | + /// |
| 697 | + /// * `x_desc` - previously initialized tensor descriptor for the input map. |
| 698 | + /// |
| 699 | + /// * `w_desc` - previously initialized tensor descriptor for the filter map. |
| 700 | + /// |
| 701 | + /// * `y_desc` - previously initialized tensor descriptor for the output map. |
| 702 | + /// |
| 703 | + /// * `conv_desc` - previously initialized convolution descriptor. |
| 704 | + /// |
| 705 | + /// * `algo` - chosen convolution algorithm. |
| 706 | + /// |
| 707 | + /// # Examples |
| 708 | + /// |
| 709 | + /// ``` |
| 710 | + /// # use std::error::Error; |
| 711 | + /// # |
| 712 | + /// # fn main() -> Result<(), Box<dyn Error>> { |
| 713 | + /// use cudnn::{ |
| 714 | + /// ConvolutionDescriptor, ConvolutionMode, CudnnContext, FilterDescriptor, |
| 715 | + /// ImplicitPrecompGemm, MathType, TensorDescriptor, NCHW, |
| 716 | + /// }; |
| 717 | + /// use cust::memory::DeviceBuffer; |
| 718 | + /// |
| 719 | + /// let ctx = CudnnContext::new()?; |
| 720 | + /// |
| 721 | + /// let padding = [0, 0]; |
| 722 | + /// let stride = [1, 1]; |
| 723 | + /// let dilation = [1, 1]; |
| 724 | + /// let groups = 1; |
| 725 | + /// let mode = ConvolutionMode::CrossCorrelation; |
| 726 | + /// let math_type = MathType::Default; |
| 727 | + /// |
| 728 | + /// // 2-dimensional convolution. |
| 729 | + /// let mut conv_desc = |
| 730 | + /// ConvolutionDescriptor::<f32, 2>::new(padding, stride, dilation, groups, mode, math_type)?; |
| 731 | + /// |
| 732 | + /// let input_desc = TensorDescriptor::<f32, _, 4>::new([3, 2, 5, 5], NCHW)?; |
| 733 | + /// let filter_desc = FilterDescriptor::<f32, _, 4>::new([3, 2, 2, 2], NCHW)?; |
| 734 | + /// let output_desc = TensorDescriptor::<f32, _, 4>::new([3, 3, 4, 4], NCHW)?; |
| 735 | + /// |
| 736 | + /// let algo = ImplicitPrecompGemm; |
| 737 | + /// |
| 738 | + /// let size = ctx.get_convolution_forward_workspace_size( |
| 739 | + /// &input_desc, |
| 740 | + /// &filter_desc, |
| 741 | + /// &output_desc, |
| 742 | + /// &conv_desc, |
| 743 | + /// &algo, |
| 744 | + /// )?; |
| 745 | + /// |
| 746 | + /// let workspace: DeviceBuffer<f32> = unsafe { DeviceBuffer::uninitialized(size)? }; |
| 747 | + /// |
| 748 | + /// # Ok(()) |
| 749 | + /// # } |
| 750 | + /// ``` |
| 751 | + pub fn get_convolution_forward_workspace_size< |
| 752 | + InType, |
| 753 | + InFmt, |
| 754 | + FilterType, |
| 755 | + FilterFmt, |
| 756 | + CompType, |
| 757 | + OutType, |
| 758 | + OutFmt, |
| 759 | + Algo, |
| 760 | + const D: usize, |
| 761 | + const N: usize, |
| 762 | + >( |
| 763 | + &self, |
| 764 | + x_desc: &TensorDescriptor<InType, InFmt, D>, |
| 765 | + w_desc: &FilterDescriptor<FilterType, FilterFmt, D>, |
| 766 | + y_desc: &TensorDescriptor<OutType, OutFmt, D>, |
| 767 | + conv_desc: &ConvolutionDescriptor<CompType, N>, |
| 768 | + algo: &Algo, |
| 769 | + ) -> Result<usize, CudnnError> |
| 770 | + where |
| 771 | + InType: DataType, |
| 772 | + InFmt: TensorFormat + SupportedType<InType>, |
| 773 | + FilterType: DataType, |
| 774 | + FilterFmt: TensorFormat + SupportedType<FilterType>, |
| 775 | + CompType: DataType, |
| 776 | + OutType: DataType, |
| 777 | + OutFmt: TensorFormat + SupportedType<OutType>, |
| 778 | + Algo: ConvolutionFwdAlgo |
| 779 | + + SupportedConvFwd<InType, InFmt, FilterType, FilterFmt, CompType, OutType, OutFmt, D, N>, |
| 780 | + { |
| 781 | + let mut size = MaybeUninit::uninit(); |
| 782 | + |
| 783 | + unsafe { |
| 784 | + sys::cudnnGetConvolutionForwardWorkspaceSize( |
| 785 | + self.raw, |
| 786 | + x_desc.raw, |
| 787 | + w_desc.raw, |
| 788 | + conv_desc.raw, |
| 789 | + y_desc.raw, |
| 790 | + algo.into_raw(), |
| 791 | + size.as_mut_ptr(), |
| 792 | + ) |
| 793 | + .into_result()?; |
| 794 | + |
| 795 | + Ok(size.assume_init()) |
| 796 | + } |
| 797 | + } |
552 | 798 | }
|
553 | 799 |
|
554 | 800 | impl Drop for CudnnContext {
|
|
0 commit comments