Skip to content

Commit 0d6ddc5

Browse files
author
rhc54
authored
Merge pull request open-mpi#1291 from vspetrov/hcoll_derived_datatypes
coll/hcoll mpi datatypes support
2 parents 640bcf6 + 58473c5 commit 0d6ddc5

File tree

6 files changed

+375
-190
lines changed

6 files changed

+375
-190
lines changed

ompi/mca/coll/hcoll/coll_hcoll.h

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,11 @@ typedef struct mca_coll_hcoll_ops_t {
4747
int (*hcoll_barrier)(void *);
4848
} mca_coll_hcoll_ops_t;
4949

50+
/* Free-list element that carries the hcoll representation of one MPI
 * datatype.  Elements of this type live in mca_coll_hcoll_component.dtypes. */
typedef struct {
51+
opal_free_list_item_t super; /* free-list base; this is what gets handed back to the dtypes list */
52+
dte_data_representation_t type; /* hcoll dte handle; destroyed with hcoll_dt_destroy() in the attribute delete callback */
53+
} mca_coll_hcoll_dtype_t;
54+
OBJ_CLASS_DECLARATION(mca_coll_hcoll_dtype_t);
5055

5156
struct mca_coll_hcoll_component_t {
5257
/** Base coll component */
@@ -80,8 +85,9 @@ struct mca_coll_hcoll_component_t {
8085

8186
/* FCA global stuff */
8287
mca_coll_hcoll_ops_t hcoll_ops;
83-
8488
ompi_free_list_t requests;
89+
opal_free_list_t dtypes;
90+
int derived_types_support_enabled;
8591
};
8692
typedef struct mca_coll_hcoll_component_t mca_coll_hcoll_component_t;
8793

ompi/mca/coll/hcoll/coll_hcoll_component.c

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414

1515
#include "coll_hcoll.h"
1616
#include "opal/mca/installdirs/installdirs.h"
17+
#include "coll_hcoll_dtypes.h"
1718

1819
/*
1920
* Public string showing the coll ompi_hcol component version number
@@ -205,8 +206,15 @@ static int hcoll_register(void)
205206
1,
206207
&mca_coll_hcoll_component.hcoll_datatype_fallback,
207208
0));
208-
209-
209+
#if HCOLL_API >= HCOLL_VERSION(3,6)
210+
CHECK(reg_int("dts",NULL,
211+
"[1|0|] Enable/Disable derived types support",
212+
1,
213+
&mca_coll_hcoll_component.derived_types_support_enabled,
214+
0));
215+
#else
216+
mca_coll_hcoll_component.derived_types_support_enabled = 0;
217+
#endif
210218
return ret;
211219
}
212220

@@ -258,7 +266,7 @@ static int hcoll_close(void)
258266

259267
HCOL_VERBOSE(5,"HCOLL FINALIZE");
260268
rc = hcoll_finalize();
261-
269+
OBJ_DESTRUCT(&cm->dtypes);
262270
opal_progress_unregister(mca_coll_hcoll_progress);
263271
if (HCOLL_SUCCESS != rc){
264272
HCOL_VERBOSE(1,"Hcol library finalize failed");

ompi/mca/coll/hcoll/coll_hcoll_dtypes.h

Lines changed: 118 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,10 @@
66
It is used to extract allreduce bcol functions where the arithmetic has to be done*/
77

88
#include "ompi/datatype/ompi_datatype.h"
9+
#include "ompi/datatype/ompi_datatype_internal.h"
910
#include "ompi/mca/op/op.h"
1011
#include "hcoll/api/hcoll_dte.h"
12+
extern int hcoll_type_attr_keyval;
1113

1214
/*to keep this at hand: Ids of the basic opal_datatypes:
1315
#define OPAL_DATATYPE_INT1 4
@@ -31,9 +33,7 @@
3133
total 15 types
3234
*/
3335

34-
35-
36-
static dte_data_representation_t* ompi_datatype_2_dte_data_rep[OPAL_DATATYPE_MAX_PREDEFINED] = {
36+
static dte_data_representation_t* ompi_datatype_2_dte_data_rep[OMPI_DATATYPE_MAX_PREDEFINED] = {
3737
&DTE_ZERO, /*OPAL_DATATYPE_LOOP 0 */
3838
&DTE_ZERO, /*OPAL_DATATYPE_END_LOOP 1 */
3939
&DTE_ZERO, /*OPAL_DATATYPE_LB 2 */
@@ -53,34 +53,113 @@ static dte_data_representation_t* ompi_datatype_2_dte_data_rep[OPAL_DATATYPE_MAX
5353
&DTE_FLOAT64, /*OPAL_DATATYPE_FLOAT8 16 */
5454
&DTE_FLOAT96, /*OPAL_DATATYPE_FLOAT12 17 */
5555
&DTE_FLOAT128, /*OPAL_DATATYPE_FLOAT16 18 */
56-
#if defined(DTE_FLOAT32_COMPLEX) && defined(DTE_FLOAT64_COMPLEX)
56+
#if defined(DTE_FLOAT32_COMPLEX)
5757
&DTE_FLOAT32_COMPLEX, /*OPAL_DATATYPE_COMPLEX8 19 */
58-
&DTE_FLOAT64_COMPLEX, /*OPAL_DATATYPE_COMPLEX16 20 */
5958
#else
60-
&DTE_ZERO, /*OPAL_DATATYPE_COMPLEX8 19 */
61-
&DTE_ZERO, /*OPAL_DATATYPE_COMPLEX16 20 */
59+
&DTE_ZERO,
60+
#endif
61+
#if defined(DTE_FLOAT64_COMPLEX)
62+
&DTE_FLOAT64_COMPLEX, /*OPAL_DATATYPE_COMPLEX32 20 */
63+
#else
64+
&DTE_ZERO,
65+
#endif
66+
#if defined(DTE_FLOAT128_COMPLEX)
67+
&DTE_FLOAT128_COMPLEX, /*OPAL_DATATYPE_COMPLEX64 21 */
68+
#else
69+
&DTE_ZERO,
6270
#endif
63-
&DTE_ZERO, /*OPAL_DATATYPE_COMPLEX32 21 */
6471
&DTE_ZERO, /*OPAL_DATATYPE_BOOL 22 */
6572
&DTE_ZERO, /*OPAL_DATATYPE_WCHAR 23 */
6673
&DTE_ZERO /*OPAL_DATATYPE_UNAVAILABLE 24 */
6774
};
6875

69-
static dte_data_representation_t ompi_dtype_2_dte_dtype(ompi_datatype_t *dtype){
76+
/* Lookup modes passed as the `mode` argument of ompi_dtype_2_hcoll_dtype(). */
enum {
77+
TRY_FIND_DERIVED, /* also try the derived-type mappings and the pointer-to-handle fallback */
78+
NO_DERIVED /* none of the TRY_FIND_DERIVED-only paths are taken */
79+
};
80+
81+
82+
#if HCOLL_API >= HCOLL_VERSION(3,6)
83+
static inline
84+
int hcoll_map_derived_type(ompi_datatype_t *dtype, dte_data_representation_t *new_dte)
85+
{
86+
int rc;
87+
if (NULL == dtype->args) {
88+
/* predefined type, shouldn't call this */
89+
return OMPI_SUCCESS;
90+
}
91+
rc = hcoll_create_mpi_type((void*)dtype, new_dte);
92+
return rc == HCOLL_SUCCESS ? OMPI_SUCCESS : OMPI_ERROR;
93+
}
94+
95+
/*
 * Return the hcoll representation associated with a derived MPI datatype.
 *
 * The datatype's attribute hash is queried for hcoll_type_attr_keyval.  When
 * an attribute is present its cached representation is returned; otherwise
 * hcoll_map_derived_type() is asked to build one.
 *
 * dtype - Open MPI datatype to look up.
 *
 * Returns DTE_ZERO when derived-type support is disabled
 * (mca_coll_hcoll_component.derived_types_support_enabled == 0) or when the
 * mapping attempt fails, so that callers testing for a zero representation
 * never see a half-built value.
 */
static dte_data_representation_t find_derived_mapping(ompi_datatype_t *dtype){
    dte_data_representation_t dte = DTE_ZERO;
    mca_coll_hcoll_dtype_t *hcoll_dtype = NULL;
    int map_found = 0;

    if (!mca_coll_hcoll_component.derived_types_support_enabled) {
        return dte;
    }

    ompi_attr_get_c(dtype->d_keyhash, hcoll_type_attr_keyval,
                    (void**)&hcoll_dtype, &map_found);
    if (map_found) {
        return hcoll_dtype->type;
    }

    /* No cached mapping yet: try to create one.  On failure discard anything
     * hcoll may have written into dte and report "no mapping" instead. */
    if (OMPI_SUCCESS != hcoll_map_derived_type(dtype, &dte)) {
        dte = DTE_ZERO;
    }
    return dte;
}
110+
111+
112+
113+
static inline dte_data_representation_t
114+
ompi_predefined_derived_2_hcoll(int ompi_id) {
115+
switch(ompi_id) {
116+
case OMPI_DATATYPE_MPI_FLOAT_INT:
117+
return DTE_FLOAT_INT;
118+
case OMPI_DATATYPE_MPI_DOUBLE_INT:
119+
return DTE_DOUBLE_INT;
120+
case OMPI_DATATYPE_MPI_LONG_INT:
121+
return DTE_LONG_INT;
122+
case OMPI_DATATYPE_MPI_SHORT_INT:
123+
return DTE_SHORT_INT;
124+
case OMPI_DATATYPE_MPI_LONG_DOUBLE_INT:
125+
return DTE_LONG_DOUBLE_INT;
126+
case OMPI_DATATYPE_MPI_2INT:
127+
return DTE_2INT;
128+
default:
129+
break;
130+
}
131+
return DTE_ZERO;
132+
}
133+
#endif
134+
135+
/*
 * Map an Open MPI datatype to an hcoll dte representation.
 *
 * NOTE: this span is a diff hunk; statements that follow an "NN-" marker are
 * the pre-commit lines removed by this change, statements that follow an
 * "NN+" marker are the ones added.  The description below covers the
 * post-commit function.
 *
 * dtype - Open MPI datatype to translate.
 * mode  - TRY_FIND_DERIVED or NO_DERIVED.
 *
 * Lookup order:
 *   - ompi id below OMPI_DATATYPE_MPI_MAX_PREDEFINED and opal id inside
 *     (0, OPAL_DATATYPE_MAX_PREDEFINED): value from ompi_datatype_2_dte_data_rep.
 *   - HCOLL_API >= 3.6 and mode == TRY_FIND_DERIVED only: predefined pair types
 *     via ompi_predefined_derived_2_hcoll(), other types via
 *     find_derived_mapping().
 *   - if the result is still zero, mode == TRY_FIND_DERIVED and
 *     hcoll_datatype_fallback is off: a DTE_ZERO-based value whose in_line flag
 *     is cleared and whose pointer_to_handle holds the address of dtype->super.
 *
 * Returns the representation found, or DTE_ZERO when none of the above applies.
 */
static dte_data_representation_t
136+
ompi_dtype_2_hcoll_dtype( ompi_datatype_t *dtype,
137+
const int mode)
138+
{
70139
int ompi_type_id = dtype->id;
71140
int opal_type_id = dtype->super.id;
72-
dte_data_representation_t dte_data_rep;
73-
if (!(dtype->super.flags & OPAL_DATATYPE_FLAG_NO_GAPS)) {
74-
ompi_type_id = -1;
141+
dte_data_representation_t dte_data_rep = DTE_ZERO;
142+
143+
if (ompi_type_id < OMPI_DATATYPE_MPI_MAX_PREDEFINED) {
144+
if (opal_type_id > 0 && opal_type_id < OPAL_DATATYPE_MAX_PREDEFINED) {
145+
dte_data_rep = *ompi_datatype_2_dte_data_rep[opal_type_id];
146+
}
147+
/* The two derived-type branches below exist only when built against hcoll >= 3.6. */
#if HCOLL_API >= HCOLL_VERSION(3,6)
148+
else if (TRY_FIND_DERIVED == mode){
149+
dte_data_rep = ompi_predefined_derived_2_hcoll(ompi_type_id);
150+
}
151+
} else {
152+
if (TRY_FIND_DERIVED == mode)
153+
dte_data_rep = find_derived_mapping(dtype);
154+
#endif
75155
}
76-
if (OPAL_UNLIKELY( ompi_type_id < 0 ||
77-
ompi_type_id >= OPAL_DATATYPE_MAX_PREDEFINED)){
156+
/* Nothing mapped and fallback disabled: hand hcoll the address of the opal datatype instead. */
if (HCOL_DTE_IS_ZERO(dte_data_rep) && TRY_FIND_DERIVED == mode &&
157+
!mca_coll_hcoll_component.hcoll_datatype_fallback) {
78158
dte_data_rep = DTE_ZERO;
79159
dte_data_rep.rep.in_line_rep.data_handle.in_line.in_line = 0;
80160
dte_data_rep.rep.in_line_rep.data_handle.pointer_to_handle = (uint64_t ) &dtype->super;
81-
return dte_data_rep;
82161
}
83-
return *ompi_datatype_2_dte_data_rep[opal_type_id];
162+
return dte_data_rep;
84163
}
85164

86165
static hcoll_dte_op_t* ompi_op_2_hcoll_op[OMPI_OP_BASE_FORTRAN_OP_MAX + 1] = {
@@ -108,4 +187,27 @@ static hcoll_dte_op_t* ompi_op_2_hcolrte_op(ompi_op_t *op) {
108187
return ompi_op_2_hcoll_op[op->o_f_to_c_index];
109188
}
110189

190+
191+
#if HCOLL_API >= HCOLL_VERSION(3,6)
192+
static int hcoll_type_attr_del_fn(MPI_Datatype type, int keyval, void *attr_val, void *extra) {
193+
int ret = OMPI_SUCCESS;
194+
mca_coll_hcoll_dtype_t *dtype =
195+
(mca_coll_hcoll_dtype_t*) attr_val;
196+
197+
assert(dtype);
198+
if (HCOLL_SUCCESS != (ret = hcoll_dt_destroy(dtype->type))) {
199+
HCOL_ERROR("failed to delete type attr: hcoll_dte_destroy returned %d",ret);
200+
return OMPI_ERROR;
201+
}
202+
OPAL_FREE_LIST_RETURN(&mca_coll_hcoll_component.dtypes,
203+
&dtype->super);
204+
205+
return OMPI_SUCCESS;
206+
}
207+
#else
208+
static int hcoll_type_attr_del_fn(MPI_Datatype type, int keyval, void *attr_val, void *extra) {
209+
/*Do nothing - it's an old version of hcoll w/o dtypes support */
210+
return OMPI_SUCCESS;
211+
}
212+
#endif
111213
#endif /* COLL_HCOLL_DTYPES_H */

ompi/mca/coll/hcoll/coll_hcoll_module.c

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,10 @@
99

1010
#include "ompi_config.h"
1111
#include "coll_hcoll.h"
12+
#include "coll_hcoll_dtypes.h"
1213

1314
int hcoll_comm_attr_keyval;
15+
int hcoll_type_attr_keyval;
1416

1517
/*
1618
* Initial query function that is invoked during MPI_INIT, allowing
@@ -211,6 +213,10 @@ int mca_coll_hcoll_progress(void)
211213
}
212214

213215

216+
OBJ_CLASS_INSTANCE(mca_coll_hcoll_dtype_t,
217+
opal_free_list_item_t,
218+
NULL,NULL);
219+
214220
/*
215221
* Invoked when there's a new communicator that has been created.
216222
* Look at the communicator and decide which set of functions and
@@ -288,6 +294,24 @@ mca_coll_hcoll_comm_query(struct ompi_communicator_t *comm, int *priority)
288294
HCOL_ERROR("Hcol comm keyval create failed");
289295
return NULL;
290296
}
297+
298+
if (mca_coll_hcoll_component.derived_types_support_enabled) {
299+
copy_fn.attr_datatype_copy_fn = (MPI_Type_internal_copy_attr_function *) MPI_TYPE_NULL_COPY_FN;
300+
del_fn.attr_datatype_delete_fn = hcoll_type_attr_del_fn;
301+
err = ompi_attr_create_keyval(TYPE_ATTR, copy_fn, del_fn, &hcoll_type_attr_keyval, NULL ,0, NULL);
302+
if (OMPI_SUCCESS != err) {
303+
cm->hcoll_enable = 0;
304+
hcoll_finalize();
305+
opal_progress_unregister(mca_coll_hcoll_progress);
306+
HCOL_ERROR("Hcol type keyval create failed");
307+
return NULL;
308+
}
309+
}
310+
OBJ_CONSTRUCT(&cm->dtypes, opal_free_list_t);
311+
opal_free_list_init(&cm->dtypes, sizeof(mca_coll_hcoll_dtype_t),
312+
OBJ_CLASS(mca_coll_hcoll_dtype_t),
313+
32, -1, 32);
314+
291315
}
292316

293317
hcoll_module = OBJ_NEW(mca_coll_hcoll_module_t);

0 commit comments

Comments
 (0)