66 It is used to extract allreduce bcol functions where the arithmetic has to be done*/
77
88#include "ompi/datatype/ompi_datatype.h"
9+ #include "ompi/datatype/ompi_datatype_internal.h"
910#include "ompi/mca/op/op.h"
1011#include "hcoll/api/hcoll_dte.h"
12+ extern int hcoll_type_attr_keyval ;
1113
1214/*to keep this at hand: Ids of the basic opal_datatypes:
1315#define OPAL_DATATYPE_INT1 4
3133total 15 types
3234*/
3335
34-
35-
36- static dte_data_representation_t * ompi_datatype_2_dte_data_rep [OPAL_DATATYPE_MAX_PREDEFINED ] = {
36+ static dte_data_representation_t * ompi_datatype_2_dte_data_rep [OMPI_DATATYPE_MAX_PREDEFINED ] = {
3737 & DTE_ZERO , /*OPAL_DATATYPE_LOOP 0 */
3838 & DTE_ZERO , /*OPAL_DATATYPE_END_LOOP 1 */
3939 & DTE_ZERO , /*OPAL_DATATYPE_LB 2 */
@@ -53,34 +53,113 @@ static dte_data_representation_t* ompi_datatype_2_dte_data_rep[OPAL_DATATYPE_MAX
5353 & DTE_FLOAT64 , /*OPAL_DATATYPE_FLOAT8 16 */
5454 & DTE_FLOAT96 , /*OPAL_DATATYPE_FLOAT12 17 */
5555 & DTE_FLOAT128 , /*OPAL_DATATYPE_FLOAT16 18 */
56- #if defined(DTE_FLOAT32_COMPLEX ) && defined ( DTE_FLOAT64_COMPLEX )
56+ #if defined(DTE_FLOAT32_COMPLEX )
5757 & DTE_FLOAT32_COMPLEX , /*OPAL_DATATYPE_COMPLEX8 19 */
58- & DTE_FLOAT64_COMPLEX , /*OPAL_DATATYPE_COMPLEX16 20 */
5958#else
60- & DTE_ZERO , /*OPAL_DATATYPE_COMPLEX8 19 */
61- & DTE_ZERO , /*OPAL_DATATYPE_COMPLEX16 20 */
59+ & DTE_ZERO ,
60+ #endif
61+ #if defined(DTE_FLOAT64_COMPLEX )
62+ & DTE_FLOAT64_COMPLEX , /*OPAL_DATATYPE_COMPLEX32 20 */
63+ #else
64+ & DTE_ZERO ,
65+ #endif
66+ #if defined(DTE_FLOAT128_COMPLEX )
67+ & DTE_FLOAT128_COMPLEX , /*OPAL_DATATYPE_COMPLEX64 21 */
68+ #else
69+ & DTE_ZERO ,
6270#endif
63- & DTE_ZERO , /*OPAL_DATATYPE_COMPLEX32 21 */
6471 & DTE_ZERO , /*OPAL_DATATYPE_BOOL 22 */
6572 & DTE_ZERO , /*OPAL_DATATYPE_WCHAR 23 */
6673 & DTE_ZERO /*OPAL_DATATYPE_UNAVAILABLE 24 */
6774};
6875
69- static dte_data_representation_t ompi_dtype_2_dte_dtype (ompi_datatype_t * dtype ){
/*
 * Lookup modes accepted by ompi_dtype_2_hcoll_dtype() through its `mode`
 * argument. The numeric values are the ones the compiler assigned implicitly
 * before; they are spelled out here only for readability.
 */
enum {
    TRY_FIND_DERIVED = 0,   /* also attempt to map derived (non-basic) datatypes */
    NO_DERIVED       = 1    /* only the basic-type lookup table is consulted */
};
80+
81+
82+ #if HCOLL_API >= HCOLL_VERSION (3 ,6 )
83+ static inline
84+ int hcoll_map_derived_type (ompi_datatype_t * dtype , dte_data_representation_t * new_dte )
85+ {
86+ int rc ;
87+ if (NULL == dtype -> args ) {
88+ /* predefined type, shouldn't call this */
89+ return OMPI_SUCCESS ;
90+ }
91+ rc = hcoll_create_mpi_type ((void * )dtype , new_dte );
92+ return rc == HCOLL_SUCCESS ? OMPI_SUCCESS : OMPI_ERROR ;
93+ }
94+
95+ static dte_data_representation_t find_derived_mapping (ompi_datatype_t * dtype ){
96+ dte_data_representation_t dte = DTE_ZERO ;
97+ mca_coll_hcoll_dtype_t * hcoll_dtype ;
98+ if (mca_coll_hcoll_component .derived_types_support_enabled ) {
99+ int map_found = 0 ;
100+ ompi_attr_get_c (dtype -> d_keyhash , hcoll_type_attr_keyval ,
101+ (void * * )& hcoll_dtype , & map_found );
102+ if (!map_found )
103+ hcoll_map_derived_type (dtype , & dte );
104+ else
105+ dte = hcoll_dtype -> type ;
106+ }
107+
108+ return dte ;
109+ }
110+
111+
112+
113+ static inline dte_data_representation_t
114+ ompi_predefined_derived_2_hcoll (int ompi_id ) {
115+ switch (ompi_id ) {
116+ case OMPI_DATATYPE_MPI_FLOAT_INT :
117+ return DTE_FLOAT_INT ;
118+ case OMPI_DATATYPE_MPI_DOUBLE_INT :
119+ return DTE_DOUBLE_INT ;
120+ case OMPI_DATATYPE_MPI_LONG_INT :
121+ return DTE_LONG_INT ;
122+ case OMPI_DATATYPE_MPI_SHORT_INT :
123+ return DTE_SHORT_INT ;
124+ case OMPI_DATATYPE_MPI_LONG_DOUBLE_INT :
125+ return DTE_LONG_DOUBLE_INT ;
126+ case OMPI_DATATYPE_MPI_2INT :
127+ return DTE_2INT ;
128+ default :
129+ break ;
130+ }
131+ return DTE_ZERO ;
132+ }
133+ #endif
134+
135+ static dte_data_representation_t
136+ ompi_dtype_2_hcoll_dtype ( ompi_datatype_t * dtype ,
137+ const int mode )
138+ {
70139 int ompi_type_id = dtype -> id ;
71140 int opal_type_id = dtype -> super .id ;
72- dte_data_representation_t dte_data_rep ;
73- if (!(dtype -> super .flags & OPAL_DATATYPE_FLAG_NO_GAPS )) {
74- ompi_type_id = -1 ;
141+ dte_data_representation_t dte_data_rep = DTE_ZERO ;
142+
143+ if (ompi_type_id < OMPI_DATATYPE_MPI_MAX_PREDEFINED ) {
144+ if (opal_type_id > 0 && opal_type_id < OPAL_DATATYPE_MAX_PREDEFINED ) {
145+ dte_data_rep = * ompi_datatype_2_dte_data_rep [opal_type_id ];
146+ }
147+ #if HCOLL_API >= HCOLL_VERSION (3 ,6 )
148+ else if (TRY_FIND_DERIVED == mode ){
149+ dte_data_rep = ompi_predefined_derived_2_hcoll (ompi_type_id );
150+ }
151+ } else {
152+ if (TRY_FIND_DERIVED == mode )
153+ dte_data_rep = find_derived_mapping (dtype );
154+ #endif
75155 }
76- if (OPAL_UNLIKELY ( ompi_type_id < 0 ||
77- ompi_type_id >= OPAL_DATATYPE_MAX_PREDEFINED )) {
156+ if (HCOL_DTE_IS_ZERO ( dte_data_rep ) && TRY_FIND_DERIVED == mode &&
157+ ! mca_coll_hcoll_component . hcoll_datatype_fallback ) {
78158 dte_data_rep = DTE_ZERO ;
79159 dte_data_rep .rep .in_line_rep .data_handle .in_line .in_line = 0 ;
80160 dte_data_rep .rep .in_line_rep .data_handle .pointer_to_handle = (uint64_t ) & dtype -> super ;
81- return dte_data_rep ;
82161 }
83- return * ompi_datatype_2_dte_data_rep [ opal_type_id ] ;
162+ return dte_data_rep ;
84163}
85164
86165static hcoll_dte_op_t * ompi_op_2_hcoll_op [OMPI_OP_BASE_FORTRAN_OP_MAX + 1 ] = {
@@ -108,4 +187,27 @@ static hcoll_dte_op_t* ompi_op_2_hcolrte_op(ompi_op_t *op) {
108187 return ompi_op_2_hcoll_op [op -> o_f_to_c_index ];
109188}
110189
190+
191+ #if HCOLL_API >= HCOLL_VERSION (3 ,6 )
192+ static int hcoll_type_attr_del_fn (MPI_Datatype type , int keyval , void * attr_val , void * extra ) {
193+ int ret = OMPI_SUCCESS ;
194+ mca_coll_hcoll_dtype_t * dtype =
195+ (mca_coll_hcoll_dtype_t * ) attr_val ;
196+
197+ assert (dtype );
198+ if (HCOLL_SUCCESS != (ret = hcoll_dt_destroy (dtype -> type ))) {
199+ HCOL_ERROR ("failed to delete type attr: hcoll_dte_destroy returned %d" ,ret );
200+ return OMPI_ERROR ;
201+ }
202+ OPAL_FREE_LIST_RETURN (& mca_coll_hcoll_component .dtypes ,
203+ & dtype -> super );
204+
205+ return OMPI_SUCCESS ;
206+ }
207+ #else
208+ static int hcoll_type_attr_del_fn (MPI_Datatype type , int keyval , void * attr_val , void * extra ) {
209+ /*Do nothing - it's an old version of hcoll w/o dtypes support */
210+ return OMPI_SUCCESS ;
211+ }
212+ #endif
111213#endif /* COLL_HCOLL_DTYPES_H */
0 commit comments