66  It is used to extract allreduce bcol functions where the arithmetic has to be done */ 
77
88#include  "ompi/datatype/ompi_datatype.h" 
9+ #include  "ompi/datatype/ompi_datatype_internal.h" 
910#include  "ompi/mca/op/op.h" 
1011#include  "hcoll/api/hcoll_dte.h" 
12+ extern  int  hcoll_type_attr_keyval ;
1113
1214/*to keep this at hand: Ids of the basic opal_datatypes: 
1315#define OPAL_DATATYPE_INT1           4 
3133total 15 types 
3234*/ 
3335
34- 
35- 
36- static  dte_data_representation_t *  ompi_datatype_2_dte_data_rep [OPAL_DATATYPE_MAX_PREDEFINED ] =  {
36+ static  dte_data_representation_t *  ompi_datatype_2_dte_data_rep [OMPI_DATATYPE_MAX_PREDEFINED ] =  {
3737    & DTE_ZERO ,                  /*OPAL_DATATYPE_LOOP           0 */ 
3838    & DTE_ZERO ,                  /*OPAL_DATATYPE_END_LOOP       1 */ 
3939    & DTE_ZERO ,                  /*OPAL_DATATYPE_LB             2 */ 
@@ -53,34 +53,113 @@ static dte_data_representation_t* ompi_datatype_2_dte_data_rep[OPAL_DATATYPE_MAX
5353    & DTE_FLOAT64 ,               /*OPAL_DATATYPE_FLOAT8         16 */ 
5454    & DTE_FLOAT96 ,               /*OPAL_DATATYPE_FLOAT12        17 */ 
5555    & DTE_FLOAT128 ,              /*OPAL_DATATYPE_FLOAT16        18 */ 
56- #if  defined(DTE_FLOAT32_COMPLEX )  &&   defined ( DTE_FLOAT64_COMPLEX ) 
56+ #if  defined(DTE_FLOAT32_COMPLEX )
5757    & DTE_FLOAT32_COMPLEX ,       /*OPAL_DATATYPE_COMPLEX8       19 */ 
58-     & DTE_FLOAT64_COMPLEX ,       /*OPAL_DATATYPE_COMPLEX16      20 */ 
5958#else 
60-     & DTE_ZERO ,                  /*OPAL_DATATYPE_COMPLEX8       19 */ 
61-     & DTE_ZERO ,                  /*OPAL_DATATYPE_COMPLEX16      20 */ 
59+     & DTE_ZERO ,
60+ #endif 
61+ #if  defined(DTE_FLOAT64_COMPLEX )
62+     & DTE_FLOAT64_COMPLEX ,       /*OPAL_DATATYPE_COMPLEX32      20 */ 
63+ #else 
64+     & DTE_ZERO ,
65+ #endif 
66+ #if  defined(DTE_FLOAT128_COMPLEX )
67+     & DTE_FLOAT128_COMPLEX ,       /*OPAL_DATATYPE_COMPLEX64     21 */ 
68+ #else 
69+     & DTE_ZERO ,
6270#endif 
63-     & DTE_ZERO ,                  /*OPAL_DATATYPE_COMPLEX32      21 */ 
6471    & DTE_ZERO ,                  /*OPAL_DATATYPE_BOOL           22 */ 
6572    & DTE_ZERO ,                  /*OPAL_DATATYPE_WCHAR          23 */ 
6673    & DTE_ZERO                    /*OPAL_DATATYPE_UNAVAILABLE    24 */ 
6774};
6875
69- static  dte_data_representation_t  ompi_dtype_2_dte_dtype (ompi_datatype_t  * dtype ){
/*
 * Lookup modes accepted by ompi_dtype_2_hcoll_dtype().
 */
enum {
    /* Besides the basic table lookup, also try to map derived datatypes. */
    TRY_FIND_DERIVED = 0,
    /* Only the basic (OPAL id based) table lookup is performed. */
    NO_DERIVED = 1
};
80+ 
81+ 
82+ #if  HCOLL_API  >= HCOLL_VERSION (3 ,6 )
83+ static  inline 
84+ int  hcoll_map_derived_type (ompi_datatype_t  * dtype , dte_data_representation_t  * new_dte )
85+ {
86+     int  rc ;
87+     if  (NULL  ==  dtype -> args ) {
88+         /* predefined type, shouldn't call this */ 
89+         return  OMPI_SUCCESS ;
90+     }
91+     rc  =  hcoll_create_mpi_type ((void * )dtype , new_dte );
92+     return  rc  ==  HCOLL_SUCCESS  ? OMPI_SUCCESS  : OMPI_ERROR ;
93+ }
94+ 
95+ static  dte_data_representation_t  find_derived_mapping (ompi_datatype_t  * dtype ){
96+     dte_data_representation_t  dte  =  DTE_ZERO ;
97+     mca_coll_hcoll_dtype_t  * hcoll_dtype ;
98+     if  (mca_coll_hcoll_component .derived_types_support_enabled ) {
99+         int  map_found  =  0 ;
100+         ompi_attr_get_c (dtype -> d_keyhash , hcoll_type_attr_keyval ,
101+                         (void * * )& hcoll_dtype , & map_found );
102+         if  (!map_found )
103+             hcoll_map_derived_type (dtype , & dte );
104+         else 
105+             dte  =  hcoll_dtype -> type ;
106+     }
107+ 
108+     return  dte ;
109+ }
110+ 
111+ 
112+ 
113+ static  inline   dte_data_representation_t 
114+ ompi_predefined_derived_2_hcoll (int  ompi_id ) {
115+     switch (ompi_id ) {
116+     case  OMPI_DATATYPE_MPI_FLOAT_INT :
117+         return  DTE_FLOAT_INT ;
118+     case  OMPI_DATATYPE_MPI_DOUBLE_INT :
119+         return  DTE_DOUBLE_INT ;
120+     case  OMPI_DATATYPE_MPI_LONG_INT :
121+         return  DTE_LONG_INT ;
122+     case  OMPI_DATATYPE_MPI_SHORT_INT :
123+         return  DTE_SHORT_INT ;
124+     case  OMPI_DATATYPE_MPI_LONG_DOUBLE_INT :
125+         return  DTE_LONG_DOUBLE_INT ;
126+     case  OMPI_DATATYPE_MPI_2INT :
127+         return  DTE_2INT ;
128+     default :
129+         break ;
130+     }
131+     return  DTE_ZERO ;
132+ }
133+ #endif 
134+ 
135+ static  dte_data_representation_t 
136+ ompi_dtype_2_hcoll_dtype ( ompi_datatype_t  * dtype ,
137+                           const  int  mode )
138+ {
70139    int  ompi_type_id  =  dtype -> id ;
71140    int  opal_type_id  =  dtype -> super .id ;
72-     dte_data_representation_t  dte_data_rep ;
73-     if  (!(dtype -> super .flags  &  OPAL_DATATYPE_FLAG_NO_GAPS )) {
74-         ompi_type_id  =  -1 ;
141+     dte_data_representation_t  dte_data_rep  =  DTE_ZERO ;
142+ 
143+     if  (ompi_type_id  <  OMPI_DATATYPE_MPI_MAX_PREDEFINED ) {
144+         if  (opal_type_id  >  0  &&  opal_type_id  <  OPAL_DATATYPE_MAX_PREDEFINED ) {
145+             dte_data_rep  =   * ompi_datatype_2_dte_data_rep [opal_type_id ];
146+         }
147+ #if  HCOLL_API  >= HCOLL_VERSION (3 ,6 )
148+         else  if  (TRY_FIND_DERIVED  ==  mode ){
149+             dte_data_rep  =   ompi_predefined_derived_2_hcoll (ompi_type_id );
150+         }
151+     } else  {
152+         if  (TRY_FIND_DERIVED  ==  mode )
153+             dte_data_rep  =   find_derived_mapping (dtype );
154+ #endif 
75155    }
76-     if  (OPAL_UNLIKELY (  ompi_type_id   <   0   || 
77-                         ompi_type_id  >=  OPAL_DATATYPE_MAX_PREDEFINED )) {
156+     if  (HCOL_DTE_IS_ZERO ( dte_data_rep )  &&   TRY_FIND_DERIVED   ==   mode   && 
157+         ! mca_coll_hcoll_component . hcoll_datatype_fallback )  {
78158        dte_data_rep  =  DTE_ZERO ;
79159        dte_data_rep .rep .in_line_rep .data_handle .in_line .in_line  =  0 ;
80160        dte_data_rep .rep .in_line_rep .data_handle .pointer_to_handle  =  (uint64_t  ) & dtype -> super ;
81-         return  dte_data_rep ;
82161    }
83-     return  * ompi_datatype_2_dte_data_rep [ opal_type_id ] ;
162+     return  dte_data_rep ;
84163}
85164
86165static  hcoll_dte_op_t *  ompi_op_2_hcoll_op [OMPI_OP_BASE_FORTRAN_OP_MAX  +  1 ] =  {
@@ -108,4 +187,27 @@ static hcoll_dte_op_t* ompi_op_2_hcolrte_op(ompi_op_t *op) {
108187    return  ompi_op_2_hcoll_op [op -> o_f_to_c_index ];
109188}
110189
190+ 
191+ #if  HCOLL_API  >= HCOLL_VERSION (3 ,6 )
192+ static  int  hcoll_type_attr_del_fn (MPI_Datatype  type , int  keyval , void  * attr_val , void  * extra ) {
193+     int  ret  =  OMPI_SUCCESS ;
194+     mca_coll_hcoll_dtype_t  * dtype  = 
195+         (mca_coll_hcoll_dtype_t * ) attr_val ;
196+ 
197+     assert (dtype );
198+     if  (HCOLL_SUCCESS  !=  (ret  =  hcoll_dt_destroy (dtype -> type ))) {
199+         HCOL_ERROR ("failed to delete type attr: hcoll_dte_destroy returned %d" ,ret );
200+         return  OMPI_ERROR ;
201+     }
202+     OPAL_FREE_LIST_RETURN (& mca_coll_hcoll_component .dtypes ,
203+                           & dtype -> super );
204+ 
205+     return  OMPI_SUCCESS ;
206+ }
207+ #else 
208+ static  int  hcoll_type_attr_del_fn (MPI_Datatype  type , int  keyval , void  * attr_val , void  * extra ) {
209+     /*Do nothing - it's an old version of hcoll w/o dtypes support */ 
210+     return  OMPI_SUCCESS ;
211+ }
212+ #endif 
111213#endif  /* COLL_HCOLL_DTYPES_H */ 
0 commit comments