Skip to content
This repository was archived by the owner on Sep 30, 2022. It is now read-only.

Commit e41cc74

Browse files
Valentin Petrov authored
and jladd-mlnx committed
coll/hcoll mpi datatypes support
(cherry picked from commit 3582bba) Conflicts: ompi/mca/coll/hcoll/coll_hcoll_rte.c
1 parent 4860038 commit e41cc74

File tree

6 files changed

+371
-191
lines changed

6 files changed

+371
-191
lines changed

ompi/mca/coll/hcoll/coll_hcoll.h

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,11 @@ typedef struct mca_coll_hcoll_ops_t {
4949
int (*hcoll_barrier)(void *);
5050
} mca_coll_hcoll_ops_t;
5151

52+
typedef struct {
53+
opal_free_list_item_t super;
54+
dte_data_representation_t type;
55+
} mca_coll_hcoll_dtype_t;
56+
OBJ_CLASS_DECLARATION(mca_coll_hcoll_dtype_t);
5257

5358
struct mca_coll_hcoll_component_t {
5459
/** Base coll component */
@@ -89,6 +94,8 @@ struct mca_coll_hcoll_component_t {
8994
/* FCA global stuff */
9095
mca_coll_hcoll_ops_t hcoll_ops;
9196
opal_free_list_t requests;
97+
opal_free_list_t dtypes;
98+
int derived_types_support_enabled;
9299
};
93100
typedef struct mca_coll_hcoll_component_t mca_coll_hcoll_component_t;
94101

ompi/mca/coll/hcoll/coll_hcoll_component.c

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717

1818
#include "coll_hcoll.h"
1919
#include "opal/mca/installdirs/installdirs.h"
20+
#include "coll_hcoll_dtypes.h"
2021

2122
/*
2223
* Public string showing the coll ompi_hcol component version number
@@ -207,7 +208,15 @@ static int hcoll_register(void)
207208
1,
208209
&mca_coll_hcoll_component.hcoll_datatype_fallback,
209210
0));
210-
211+
#if HCOLL_API >= HCOLL_VERSION(3,6)
212+
CHECK(reg_int("dts",NULL,
213+
"[1|0|] Enable/Disable derived types support",
214+
1,
215+
&mca_coll_hcoll_component.derived_types_support_enabled,
216+
0));
217+
#else
218+
mca_coll_hcoll_component.derived_types_support_enabled = 0;
219+
#endif
211220
mca_coll_hcoll_component.compiletime_version = HCOLL_VERNO_STRING;
212221
mca_base_component_var_register(&mca_coll_hcoll_component.super.collm_version,
213222
MCA_COMPILETIME_VER,
@@ -278,7 +287,7 @@ static int hcoll_close(void)
278287

279288
HCOL_VERBOSE(5,"HCOLL FINALIZE");
280289
rc = hcoll_finalize();
281-
290+
OBJ_DESTRUCT(&cm->dtypes);
282291
opal_progress_unregister(mca_coll_hcoll_progress);
283292
if (HCOLL_SUCCESS != rc){
284293
HCOL_VERBOSE(1,"Hcol library finalize failed");

ompi/mca/coll/hcoll/coll_hcoll_dtypes.h

Lines changed: 118 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,10 @@
66
It is used to extract allreduce bcol functions where the arithmetic has to be done */
77

88
#include "ompi/datatype/ompi_datatype.h"
9+
#include "ompi/datatype/ompi_datatype_internal.h"
910
#include "ompi/mca/op/op.h"
1011
#include "hcoll/api/hcoll_dte.h"
12+
extern int hcoll_type_attr_keyval;
1113

1214
/*to keep this at hand: Ids of the basic opal_datatypes:
1315
#define OPAL_DATATYPE_INT1 4
@@ -31,9 +33,7 @@
3133
total 15 types
3234
*/
3335

34-
35-
36-
static dte_data_representation_t* ompi_datatype_2_dte_data_rep[OPAL_DATATYPE_MAX_PREDEFINED] = {
36+
static dte_data_representation_t* ompi_datatype_2_dte_data_rep[OMPI_DATATYPE_MAX_PREDEFINED] = {
3737
&DTE_ZERO, /*OPAL_DATATYPE_LOOP 0 */
3838
&DTE_ZERO, /*OPAL_DATATYPE_END_LOOP 1 */
3939
&DTE_ZERO, /*OPAL_DATATYPE_LB 2 */
@@ -53,34 +53,113 @@ static dte_data_representation_t* ompi_datatype_2_dte_data_rep[OPAL_DATATYPE_MAX
5353
&DTE_FLOAT64, /*OPAL_DATATYPE_FLOAT8 16 */
5454
&DTE_FLOAT96, /*OPAL_DATATYPE_FLOAT12 17 */
5555
&DTE_FLOAT128, /*OPAL_DATATYPE_FLOAT16 18 */
56-
#if defined(DTE_FLOAT32_COMPLEX) && defined(DTE_FLOAT64_COMPLEX)
56+
#if defined(DTE_FLOAT32_COMPLEX)
5757
&DTE_FLOAT32_COMPLEX, /*OPAL_DATATYPE_COMPLEX8 19 */
58-
&DTE_FLOAT64_COMPLEX, /*OPAL_DATATYPE_COMPLEX16 20 */
5958
#else
60-
&DTE_ZERO, /*OPAL_DATATYPE_COMPLEX8 19 */
61-
&DTE_ZERO, /*OPAL_DATATYPE_COMPLEX16 20 */
59+
&DTE_ZERO,
60+
#endif
61+
#if defined(DTE_FLOAT64_COMPLEX)
62+
&DTE_FLOAT64_COMPLEX, /*OPAL_DATATYPE_COMPLEX16 20 */
63+
#else
64+
&DTE_ZERO,
65+
#endif
66+
#if defined(DTE_FLOAT128_COMPLEX)
67+
&DTE_FLOAT128_COMPLEX, /*OPAL_DATATYPE_COMPLEX32 21 */
68+
#else
69+
&DTE_ZERO,
6270
#endif
63-
&DTE_ZERO, /*OPAL_DATATYPE_COMPLEX32 21 */
6471
&DTE_ZERO, /*OPAL_DATATYPE_BOOL 22 */
6572
&DTE_ZERO, /*OPAL_DATATYPE_WCHAR 23 */
6673
&DTE_ZERO /*OPAL_DATATYPE_UNAVAILABLE 24 */
6774
};
6875

69-
static dte_data_representation_t ompi_dtype_2_dte_dtype(ompi_datatype_t *dtype){
76+
enum {
77+
TRY_FIND_DERIVED,
78+
NO_DERIVED
79+
};
80+
81+
82+
#if HCOLL_API >= HCOLL_VERSION(3,6)
83+
static inline
84+
int hcoll_map_derived_type(ompi_datatype_t *dtype, dte_data_representation_t *new_dte)
85+
{
86+
int rc;
87+
if (NULL == dtype->args) {
88+
/* predefined type, shouldn't call this */
89+
return OMPI_SUCCESS;
90+
}
91+
rc = hcoll_create_mpi_type((void*)dtype, new_dte);
92+
return rc == HCOLL_SUCCESS ? OMPI_SUCCESS : OMPI_ERROR;
93+
}
94+
95+
/**
 * Look up (or create) the hcoll representation of a derived datatype.
 *
 * The representation cached on the datatype under hcoll_type_attr_keyval
 * is preferred; when no attribute is present, hcoll is asked to map the
 * type through hcoll_map_derived_type().
 *
 * @param dtype  derived datatype to resolve
 * @return the hcoll representation, or DTE_ZERO when derived type support
 *         is disabled, the cached attribute carries no value, or the
 *         mapping attempt fails
 */
static dte_data_representation_t find_derived_mapping(ompi_datatype_t *dtype)
{
    dte_data_representation_t dte = DTE_ZERO;
    mca_coll_hcoll_dtype_t *hcoll_dtype = NULL;
    int map_found = 0;

    if (!mca_coll_hcoll_component.derived_types_support_enabled) {
        return dte;
    }

    ompi_attr_get_c(dtype->d_keyhash, hcoll_type_attr_keyval,
                    (void**)&hcoll_dtype, &map_found);
    if (map_found) {
        /* Never dereference a missing attribute value. */
        return (NULL != hcoll_dtype) ? hcoll_dtype->type : dte;
    }

    if (OMPI_SUCCESS != hcoll_map_derived_type(dtype, &dte)) {
        /* Do not hand a possibly half-written representation to the
         * caller: DTE_ZERO lets it take its "no mapping" path. */
        dte = DTE_ZERO;
    }
    return dte;
}
110+
111+
112+
113+
static inline dte_data_representation_t
114+
ompi_predefined_derived_2_hcoll(int ompi_id) {
115+
switch(ompi_id) {
116+
case OMPI_DATATYPE_MPI_FLOAT_INT:
117+
return DTE_FLOAT_INT;
118+
case OMPI_DATATYPE_MPI_DOUBLE_INT:
119+
return DTE_DOUBLE_INT;
120+
case OMPI_DATATYPE_MPI_LONG_INT:
121+
return DTE_LONG_INT;
122+
case OMPI_DATATYPE_MPI_SHORT_INT:
123+
return DTE_SHORT_INT;
124+
case OMPI_DATATYPE_MPI_LONG_DOUBLE_INT:
125+
return DTE_LONG_DOUBLE_INT;
126+
case OMPI_DATATYPE_MPI_2INT:
127+
return DTE_2INT;
128+
default:
129+
break;
130+
}
131+
return DTE_ZERO;
132+
}
133+
#endif
134+
135+
static dte_data_representation_t
136+
ompi_dtype_2_hcoll_dtype( ompi_datatype_t *dtype,
137+
const int mode)
138+
{
70139
int ompi_type_id = dtype->id;
71140
int opal_type_id = dtype->super.id;
72-
dte_data_representation_t dte_data_rep;
73-
if (!(dtype->super.flags & OPAL_DATATYPE_FLAG_NO_GAPS)) {
74-
ompi_type_id = -1;
141+
dte_data_representation_t dte_data_rep = DTE_ZERO;
142+
143+
if (ompi_type_id < OMPI_DATATYPE_MPI_MAX_PREDEFINED) {
144+
if (opal_type_id > 0 && opal_type_id < OPAL_DATATYPE_MAX_PREDEFINED) {
145+
dte_data_rep = *ompi_datatype_2_dte_data_rep[opal_type_id];
146+
}
147+
#if HCOLL_API >= HCOLL_VERSION(3,6)
148+
else if (TRY_FIND_DERIVED == mode){
149+
dte_data_rep = ompi_predefined_derived_2_hcoll(ompi_type_id);
150+
}
151+
} else {
152+
if (TRY_FIND_DERIVED == mode)
153+
dte_data_rep = find_derived_mapping(dtype);
154+
#endif
75155
}
76-
if (OPAL_UNLIKELY( ompi_type_id < 0 ||
77-
ompi_type_id >= OPAL_DATATYPE_MAX_PREDEFINED)){
156+
if (HCOL_DTE_IS_ZERO(dte_data_rep) && TRY_FIND_DERIVED == mode &&
157+
!mca_coll_hcoll_component.hcoll_datatype_fallback) {
78158
dte_data_rep = DTE_ZERO;
79159
dte_data_rep.rep.in_line_rep.data_handle.in_line.in_line = 0;
80160
dte_data_rep.rep.in_line_rep.data_handle.pointer_to_handle = (uint64_t ) &dtype->super;
81-
return dte_data_rep;
82161
}
83-
return *ompi_datatype_2_dte_data_rep[opal_type_id];
162+
return dte_data_rep;
84163
}
85164

86165
static hcoll_dte_op_t* ompi_op_2_hcoll_op[OMPI_OP_BASE_FORTRAN_OP_MAX + 1] = {
@@ -108,4 +187,27 @@ static hcoll_dte_op_t* ompi_op_2_hcolrte_op(ompi_op_t *op) {
108187
return ompi_op_2_hcoll_op[op->o_f_to_c_index];
109188
}
110189

190+
191+
#if HCOLL_API >= HCOLL_VERSION(3,6)
192+
static int hcoll_type_attr_del_fn(MPI_Datatype type, int keyval, void *attr_val, void *extra) {
193+
int ret = OMPI_SUCCESS;
194+
mca_coll_hcoll_dtype_t *dtype =
195+
(mca_coll_hcoll_dtype_t*) attr_val;
196+
197+
assert(dtype);
198+
if (HCOLL_SUCCESS != (ret = hcoll_dt_destroy(dtype->type))) {
199+
HCOL_ERROR("failed to delete type attr: hcoll_dte_destroy returned %d",ret);
200+
return OMPI_ERROR;
201+
}
202+
opal_free_list_return(&mca_coll_hcoll_component.dtypes,
203+
&dtype->super);
204+
205+
return OMPI_SUCCESS;
206+
}
207+
#else
208+
static int hcoll_type_attr_del_fn(MPI_Datatype type, int keyval, void *attr_val, void *extra) {
209+
/*Do nothing - it's an old version of hcoll w/o dtypes support */
210+
return OMPI_SUCCESS;
211+
}
212+
#endif
111213
#endif /* COLL_HCOLL_DTYPES_H */

ompi/mca/coll/hcoll/coll_hcoll_module.c

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,10 @@
1010

1111
#include "ompi_config.h"
1212
#include "coll_hcoll.h"
13+
#include "coll_hcoll_dtypes.h"
1314

1415
int hcoll_comm_attr_keyval;
16+
int hcoll_type_attr_keyval;
1517

1618
/*
1719
* Initial query function that is invoked during MPI_INIT, allowing
@@ -240,6 +242,10 @@ int mca_coll_hcoll_progress(void)
240242
}
241243

242244

245+
OBJ_CLASS_INSTANCE(mca_coll_hcoll_dtype_t,
246+
opal_free_list_item_t,
247+
NULL,NULL);
248+
243249
/*
244250
* Invoked when there's a new communicator that has been created.
245251
* Look at the communicator and decide which set of functions and
@@ -317,6 +323,24 @@ mca_coll_hcoll_comm_query(struct ompi_communicator_t *comm, int *priority)
317323
HCOL_ERROR("Hcol comm keyval create failed");
318324
return NULL;
319325
}
326+
327+
if (mca_coll_hcoll_component.derived_types_support_enabled) {
328+
copy_fn.attr_datatype_copy_fn = (MPI_Type_internal_copy_attr_function *) MPI_TYPE_NULL_COPY_FN;
329+
del_fn.attr_datatype_delete_fn = hcoll_type_attr_del_fn;
330+
err = ompi_attr_create_keyval(TYPE_ATTR, copy_fn, del_fn, &hcoll_type_attr_keyval, NULL ,0, NULL);
331+
if (OMPI_SUCCESS != err) {
332+
cm->hcoll_enable = 0;
333+
hcoll_finalize();
334+
opal_progress_unregister(mca_coll_hcoll_progress);
335+
HCOL_ERROR("Hcol type keyval create failed");
336+
return NULL;
337+
}
338+
}
339+
OBJ_CONSTRUCT(&cm->dtypes, opal_free_list_t);
340+
opal_free_list_init(&cm->dtypes, sizeof(mca_coll_hcoll_dtype_t),
341+
8, OBJ_CLASS(mca_coll_hcoll_dtype_t), 0, 0,
342+
32, -1, 32, NULL, 0, NULL, NULL, NULL);
343+
320344
}
321345

322346
hcoll_module = OBJ_NEW(mca_coll_hcoll_module_t);

0 commit comments

Comments
 (0)