@@ -15,75 +15,71 @@ limitations under the License.
1515
#include "gmm.h"

#include <cstdint>
#include <vector>
1717
18- py::tuple init ()
19- {
20- torch::Tensor gmm_tensor = torch::zeros ({GMM_COUNT, GMM_COMPONENT_COUNT}, torch::dtype (torch::kFloat32 ).device (torch::kCUDA ));
21- torch::Tensor scratch_tensor = torch::empty ({1 }, torch::dtype (torch::kFloat32 ).device (torch::kCUDA ));
22- return py::make_tuple (gmm_tensor, scratch_tensor);
18+ py::tuple init () {
19+ torch::Tensor gmm_tensor =
20+ torch::zeros ({GMM_COUNT, GMM_COMPONENT_COUNT}, torch::dtype (torch::kFloat32 ).device (torch::kCUDA ));
21+ torch::Tensor scratch_tensor = torch::empty ({1 }, torch::dtype (torch::kFloat32 ).device (torch::kCUDA ));
22+ return py::make_tuple (gmm_tensor, scratch_tensor);
2323}
2424
25- void learn (torch::Tensor gmm_tensor, torch::Tensor scratch_tensor, torch::Tensor input_tensor, torch::Tensor label_tensor)
26- {
27- c10::DeviceType device_type = input_tensor. device (). type ();
28-
29- unsigned int batch_count = input_tensor. size ( 0 );
30- unsigned int element_count = input_tensor.stride ( 1 );
31-
32- unsigned int scratch_size = batch_count * (element_count + GMM_COMPONENT_COUNT * GMM_COUNT * (element_count / ( 32 * 32 )) );
33-
34- if (scratch_tensor. size ( 0 ) < scratch_size)
35- {
36- scratch_tensor. resize_ ({scratch_size} );
37- }
38-
39- float * gmm = gmm_tensor. data_ptr < float >( );
40- float * scratch = scratch_tensor. data_ptr < float >();
41- float * input = input_tensor. data_ptr < float >();
42- int * labels = label_tensor .data_ptr <int >();
43-
44- if (device_type == torch:: kCUDA )
45- {
46- learn_cuda (input, labels, gmm, scratch, batch_count, element_count);
47- }
48- else
49- {
50- learn_cpu (input, labels, gmm, scratch, batch_count, element_count);
51- }
25+ void learn (
26+ torch::Tensor gmm_tensor,
27+ torch::Tensor scratch_tensor,
28+ torch::Tensor input_tensor,
29+ torch::Tensor label_tensor) {
30+ c10::DeviceType device_type = input_tensor.device (). type ( );
31+
32+ unsigned int batch_count = input_tensor. size ( 0 );
33+ unsigned int element_count = input_tensor. stride ( 1 );
34+
35+ unsigned int scratch_size =
36+ batch_count * (element_count + GMM_COMPONENT_COUNT * GMM_COUNT * (element_count / ( 32 * 32 )) );
37+
38+ if (scratch_tensor. size ( 0 ) < scratch_size) {
39+ scratch_tensor. resize_ ({scratch_size} );
40+ }
41+
42+ float * gmm = gmm_tensor .data_ptr <float >();
43+ float * scratch = scratch_tensor. data_ptr < float >();
44+ float * input = input_tensor. data_ptr < float >();
45+ int * labels = label_tensor. data_ptr < int >();
46+
47+ if (device_type == torch:: kCUDA ) {
48+ learn_cuda (input, labels, gmm, scratch, batch_count, element_count);
49+ } else {
50+ learn_cpu (input, labels, gmm, scratch, batch_count, element_count);
51+ }
5252}
5353
54- torch::Tensor apply (torch::Tensor gmm_tensor, torch::Tensor input_tensor)
55- {
56- c10::DeviceType device_type = input_tensor.device ().type ();
57-
58- unsigned int dim = input_tensor.dim ();
59- unsigned int batch_count = input_tensor.size (0 );
60- unsigned int element_count = input_tensor.stride (1 );
61-
62- long int * output_size = new long int [dim];
63- memcpy (output_size, input_tensor.sizes ().data (), dim * sizeof (long int ));
64- output_size[1 ] = MIXTURE_COUNT;
65- torch::Tensor output_tensor = torch::empty (c10::IntArrayRef (output_size, dim), torch::dtype (torch::kFloat32 ).device (device_type));
66- delete output_size;
67-
68- const float * gmm = gmm_tensor.data_ptr <float >();
69- const float * input = input_tensor.data_ptr <float >();
70- float * output = output_tensor.data_ptr <float >();
71-
72- if (device_type == torch::kCUDA )
73- {
74- apply_cuda (gmm, input, output, batch_count, element_count);
75- }
76- else
77- {
78- apply_cpu (gmm, input, output, batch_count, element_count);
79- }
80-
81- return output_tensor;
54+ torch::Tensor apply (torch::Tensor gmm_tensor, torch::Tensor input_tensor) {
55+ c10::DeviceType device_type = input_tensor.device ().type ();
56+
57+ unsigned int dim = input_tensor.dim ();
58+ unsigned int batch_count = input_tensor.size (0 );
59+ unsigned int element_count = input_tensor.stride (1 );
60+
61+ long int * output_size = new long int [dim];
62+ memcpy (output_size, input_tensor.sizes ().data (), dim * sizeof (long int ));
63+ output_size[1 ] = MIXTURE_COUNT;
64+ torch::Tensor output_tensor =
65+ torch::empty (c10::IntArrayRef (output_size, dim), torch::dtype (torch::kFloat32 ).device (device_type));
66+ delete output_size;
67+
68+ const float * gmm = gmm_tensor.data_ptr <float >();
69+ const float * input = input_tensor.data_ptr <float >();
70+ float * output = output_tensor.data_ptr <float >();
71+
72+ if (device_type == torch::kCUDA ) {
73+ apply_cuda (gmm, input, output, batch_count, element_count);
74+ } else {
75+ apply_cpu (gmm, input, output, batch_count, element_count);
76+ }
77+
78+ return output_tensor;
8279}
8380
84- PYBIND11_MODULE (TORCH_EXTENSION_NAME, m)
85- {
86- m.def (" init" , torch::wrap_pybind_function (init));
87- m.def (" learn" , torch::wrap_pybind_function (learn));
88- m.def (" apply" , torch::wrap_pybind_function (apply));
81+ PYBIND11_MODULE (TORCH_EXTENSION_NAME, m) {
82+ m.def (" init" , torch::wrap_pybind_function (init));
83+ m.def (" learn" , torch::wrap_pybind_function (learn));
84+ m.def (" apply" , torch::wrap_pybind_function (apply));
8985}
0 commit comments