Skip to content

Commit 2ef211a

Browse files
committed
weights and source out support/fixes
fixes problems with dist graph constructors and cart_shift output. Signed-off-by: Howard Pritchard <[email protected]>
1 parent 0536c3a commit 2ef211a

File tree

6 files changed

+85
-17
lines changed

6 files changed

+85
-17
lines changed

ompi/mpi/bindings/ompi_bindings/c.py

Lines changed: 19 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -307,14 +307,17 @@ def generate_new_datatype_convert_fn_intern_to_abi(self):
307307
self.dump('}')
308308

309309
def generic_convert(self, fn_name, param_name, type_, value_names, offset=None):
310-
if type_ not in c_intrinsic_types:
311-
if (type_[-1] == '*'):
312-
intern_type = self.mangle_name(type_[:-1].strip())
313-
intern_type = intern_type + ' *'
314-
else:
315-
intern_type = self.mangle_name(type_)
310+
is_ptr_arg = False
311+
tmp_type = type_
312+
if (tmp_type[-1] == '*'):
313+
is_ptr_arg = True
314+
tmp_type = tmp_type[:-1].strip()
315+
if tmp_type not in c_intrinsic_types:
316+
intern_type = self.mangle_name(tmp_type)
316317
else:
317-
intern_type = type_
318+
intern_type = tmp_type
319+
if (is_ptr_arg == True):
320+
intern_type = intern_type + ' *'
318321
self.dump(f'{consts.INLINE_ATTRS} {type_} {fn_name}({intern_type} {param_name})')
319322
self.dump('{')
320323
lines = []
@@ -447,6 +450,9 @@ def generate_tag_convert_fn_intern_to_abi(self):
447450
def generate_source_convert_fn(self):
448451
self.generic_convert(ConvertFuncs.SOURCE, 'source', 'int', consts.RESERVED_SOURCE)
449452

453+
def generate_source_convert_fn_intern_to_abi(self):
454+
self.generic_convert_reverse(ConvertOMPIToStandard.SOURCE, 'tag', 'int', consts.RESERVED_SOURCE)
455+
450456
def generate_root_convert_fn(self):
451457
self.generic_convert(ConvertFuncs.ROOT, 'root', 'int', consts.RESERVED_ROOT)
452458

@@ -504,6 +510,9 @@ def generate_t_cb_safety_convert_fn(self):
504510
def generate_comm_split_type_convert_fn(self):
505511
self.generic_convert(ConvertFuncs.SPLIT_TYPE, 'split_type', 'int', consts.COMMUNICATOR_SPLIT_TYPES)
506512

513+
def generate_weight_convert_fn(self):
514+
self.generic_convert(ConvertFuncs.WEIGHTS, 'weights', 'int *', consts.RESERVED_WEIGHTS)
515+
507516
def generate_pointer_convert_fn(self, type_, fn_name, constants):
508517
abi_type = self.mangle_name(type_)
509518
self.dump(f'{consts.INLINE_ATTRS} void {fn_name}({abi_type} *ptr)')
@@ -636,19 +645,21 @@ def dump_code(self):
636645
self.generate_t_source_order_convert_fn_intern_to_abi()
637646
self.generate_pvar_class_convert_fn()
638647
self.generate_pvar_class_convert_fn_intern_to_abi()
648+
self.generate_source_convert_fn()
649+
self.generate_source_convert_fn_intern_to_abi()
639650

640651
#
641652
# the following only need abi to intern converters
642653
#
643654
self.generate_comm_copy_attr_convert_fn()
644655
self.generate_comm_delete_attr_convert_fn()
645656
self.generate_comm_split_type_convert_fn()
657+
self.generate_weight_convert_fn()
646658

647659
#
648660
# the following only need intern to abi converters
649661
#
650662
self.generate_comm_cmp_convert_fn_intern_to_abi()
651-
self.generate_source_convert_fn()
652663
self.generate_root_convert_fn()
653664
self.generate_t_cb_safety_convert_fn()
654665

ompi/mpi/bindings/ompi_bindings/c_type.py

Lines changed: 52 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -626,12 +626,31 @@ def init_code(self):
626626
def type_text(self, enable_count=False):
627627
return 'int'
628628

629-
@Type.add_type('COMM', abi_type=['ompi'])
630-
class TypeCommunicator(Type):
629+
@Type.add_type('SOURCE_OUT', abi_type=['ompi'])
630+
class TypeSourceOut(Type):
631631

632632
def type_text(self, enable_count=False):
633-
return 'MPI_Comm'
633+
return 'int *'
634+
635+
@Type.add_type('SOURCE_OUT', abi_type=['standard'])
636+
class TypeSourceOutStandard(StandardABIType):
634637

638+
@property
639+
def final_code(self):
640+
return [f'*{self.name} = {ConvertOMPIToStandard.SOURCE}(*{self.name});']
641+
642+
def type_text(self, enable_count=False):
643+
return f'int *'
644+
645+
@property
646+
def argument(self):
647+
return f'(int *) {self.name}'
648+
649+
@Type.add_type('COMM', abi_type=['ompi'])
650+
class TypeCommunicator(Type):
651+
652+
def type_text(self, enable_count=False):
653+
return 'MPI_Comm'
635654

636655
@Type.add_type('COMM', abi_type=['standard'])
637656
class TypeCommunicatorStandard(StandardABIType):
@@ -2522,6 +2541,36 @@ def init_code(self):
25222541
def type_text(self, enable_count=False):
25232542
return 'int'
25242543

2544+
@Type.add_type('WEIGHTS', abi_type=['ompi'])
2545+
class TypeWeightType(Type):
2546+
2547+
def type_text(self, enable_count=False):
2548+
return 'const int *'
2549+
2550+
def parameter(self, enable_count=False, **kwargs):
2551+
return f'const int {self.name}[]'
2552+
2553+
#
2554+
# TODO this can be made better if we could handle "const int"
2555+
# better as arg to the converter code.
2556+
#
2557+
@Type.add_type('WEIGHTS', abi_type=['standard'])
2558+
class TyperWeightStandard(StandardABIType):
2559+
2560+
@property
2561+
def init_code(self):
2562+
return [f'int *{self.tmpname} = (int *){ConvertFuncs.WEIGHTS}((int *){self.name});']
2563+
2564+
def type_text(self, enable_count=False):
2565+
return 'const int *'
2566+
2567+
def parameter(self, enable_count=False, **kwargs):
2568+
return f'const int * {self.name}'
2569+
2570+
@property
2571+
def argument(self):
2572+
return f'(int *){self.tmpname}'
2573+
25252574
@Type.add_type('COMM_CMP_OUT', abi_type=['ompi'])
25262575
class TypeCommCmpOut(Type):
25272576

ompi/mpi/bindings/ompi_bindings/consts.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -306,13 +306,19 @@
306306

307307
RESERVED_SOURCE = [
308308
'MPI_ANY_SOURCE',
309+
'MPI_PROC_NULL',
309310
]
310311

311312
RESERVED_ROOT = [
312313
'MPI_ROOT',
313314
'MPI_PROC_NULL',
314315
]
315316

317+
RESERVED_WEIGHTS = [
318+
'MPI_UNWEIGHTED',
319+
'MPI_WEIGHTS_EMPTY'
320+
]
321+
316322
RESERVED_PVAR_SESSIONS = [
317323
'MPI_T_PVAR_SESSION_NULL',
318324
]
@@ -454,6 +460,7 @@ class ConvertFuncs:
454460
COMM_COPY_ATTR_FUNCTION = 'ompi_convert_comm_copy_attr_fn_intern_comm_copy_attr_fn'
455461
COMM_DELETE_ATTR_FUNCTION = 'ompi_convert_comm_delete_attr_fn_intern_comm_delete_attr_fn'
456462
SPLIT_TYPE = 'ompi_convert_split_type_intern_type'
463+
WEIGHTS = 'ompi_convert_weight_intern_weight'
457464

458465

459466
class ConvertOMPIToStandard:
@@ -484,6 +491,7 @@ class ConvertOMPIToStandard:
484491
T_SOURCE_ORDER = 'ompi_convert_source_order_ompi_to_standard'
485492
ATTR_KEY = 'ompi_convert_attr_key_ompi_to_standard'
486493
COMM_CMP = 'ompi_convert_comm_cmp_ompi_to_standard'
494+
SOURCE = 'ompi_convert_source_ompi_to_standard'
487495

488496

489497
# Inline function attributes

ompi/mpi/c/cart_shift.c.in

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
* Copyright (c) 2012-2013 Inria. All rights reserved.
1414
* Copyright (c) 2015 Research Organization for Information Science
1515
* and Technology (RIST). All rights reserved.
16-
* Copyright (c) 2024 Triad National Security, LLC. All rights
16+
* Copyright (c) 2024-2025 Triad National Security, LLC. All rights
1717
* reserved.
1818
* $COPYRIGHT$
1919
*
@@ -32,7 +32,7 @@
3232
#include "ompi/memchecker.h"
3333

3434
PROTOTYPE ERROR_CLASS cart_shift(COMM comm, INT direction, INT disp,
35-
INT_OUT rank_source, INT_OUT rank_dest)
35+
SOURCE_OUT rank_source, SOURCE_OUT rank_dest)
3636
{
3737
int err;
3838

ompi/mpi/c/dist_graph_create.c.in

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@
2828
#include "ompi/mca/topo/base/base.h"
2929

3030
PROTOTYPE ERROR_CLASS dist_graph_create(COMM comm_old, INT n, INT_ARRAY sources,
31-
INT_ARRAY degrees, INT_ARRAY destinations, INT_ARRAY weights,
31+
INT_ARRAY degrees, INT_ARRAY destinations, WEIGHTS weights,
3232
INFO info, INT reorder, COMM_OUT newcomm)
3333
{
3434
mca_topo_base_module_t* topo;

ompi/mpi/c/dist_graph_create_adjacent.c.in

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
* Copyright (c) 2015 Research Organization for Information Science
1414
* and Technology (RIST). All rights reserved.
1515
* Copyright (c) 2017 IBM Corporation. All rights reserved.
16-
* Copyright (c) 2024 Triad National Security, LLC. All rights
16+
* Copyright (c) 2024-2025 Triad National Security, LLC. All rights
1717
* reserved.
1818
* $COPYRIGHT$
1919
*
@@ -35,8 +35,8 @@
3535

3636
PROTOTYPE ERROR_CLASS dist_graph_create_adjacent(COMM comm_old,
3737
INT indegree, INT_ARRAY sources,
38-
INT_ARRAY sourceweights, INT outdegree,
39-
INT_ARRAY destinations, INT_ARRAY destweights,
38+
WEIGHTS sourceweights, INT outdegree,
39+
INT_ARRAY destinations, WEIGHTS destweights,
4040
INFO info, INT reorder,
4141
COMM_OUT comm_dist_graph)
4242
{

0 commit comments

Comments
 (0)