Skip to content

Commit b9649dc

Browse files
committed
abi: move mpi_type_get_envelope etc. into templates
MPI_Type_get_envelop and MPI_Type_get_contents are a big mess as the MPI Forum decided not just to add big variants but also add additional arguments for the big count variants. So this necessitated enhancements to the binding infrastructure to support optional suppressing of bc and non-bc variants of prototype files. Signed-off-by: Howard Pritchard <[email protected]>
1 parent d0d0f69 commit b9649dc

File tree

10 files changed

+399
-56
lines changed

10 files changed

+399
-56
lines changed

ompi/mpi/bindings/bindings.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,9 @@ def main():
6161
# parser = argparse.ArgumentParser(description='C ABI binding generation code')
6262
parser_gen.add_argument('type', choices=('ompi', 'standard'),
6363
help='generate the OMPI ABI functions or the standard ABI functions')
64-
parser_gen.add_argument('--mpit', action='store_true', help='generate MPI T code generation')
64+
parser_gen.add_argument('--mpit', action='store_true', help='generate MPI T code')
65+
parser_gen.add_argument('--suppress_bc', action='store_true', help='do not generate big count variant')
66+
parser_gen.add_argument('--suppress_nbc', action='store_true', help='do not generate int count variant')
6567
parser_gen.add_argument('source_file', help='template file to use for C code generation')
6668
parser_gen.set_defaults(handler=lambda args, out: c.generate_source(args, out))
6769
args = parser.parse_args()

ompi/mpi/bindings/ompi_bindings/c.py

Lines changed: 33 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -416,17 +416,22 @@ def print_cdefs_for_abi(out, abi_type='ompi'):
416416
out.dump('#undef OMPI_ABI_SRC')
417417
out.dump('#define OMPI_ABI_SRC 1')
418418

419-
def ompi_abi(base_name, template, out):
419+
def ompi_abi(base_name, template, out, suppress_bc=False, suppress_nbc=False):
420420
"""Generate the OMPI ABI functions."""
421421
template.print_header(out)
422-
print_profiling_header(base_name, out)
423-
print_cdefs_for_bigcount(out)
424-
print_cdefs_for_abi(out)
425-
out.dump(template.prototype.signature(base_name, abi_type='ompi'))
426-
template.print_body(func_name=base_name, out=out)
422+
if suppress_nbc == False:
423+
print_profiling_header(base_name, out)
424+
print_cdefs_for_bigcount(out)
425+
print_cdefs_for_abi(out)
426+
out.dump(template.prototype.signature(base_name, abi_type='ompi'))
427+
template.print_body(func_name=base_name, out=out)
427428
# Check if we need to generate the bigcount interface
428-
if util.prototype_has_bigcount(template.prototype):
429-
base_name_c = f'{base_name}_c'
429+
if util.prototype_has_bigcount(template.prototype) and suppress_bc == False:
430+
# there are some special cases where we need to explicitly define the bigcount functions in the template file
431+
if base_name[-2:] == "_c":
432+
base_name_c = f'{base_name}'
433+
else:
434+
base_name_c = f'{base_name}_c'
430435
print_profiling_header(base_name_c, out)
431436
print_cdefs_for_bigcount(out, enable_count=True)
432437
print_cdefs_for_abi(out)
@@ -438,7 +443,7 @@ def ompi_abi(base_name, template, out):
438443
ABI_INTERNAL_CONVERTOR = 'ompi/mpi/c/abi_converters.h'
439444

440445

441-
def standard_abi(base_name, template, out):
446+
def standard_abi(base_name, template, out, suppress_bc=False, suppress_nbc=False):
442447
"""Generate the standard ABI functions."""
443448
template.print_header(out)
444449
out.dump(f'#include "{ABI_INTERNAL_HEADER}"')
@@ -457,14 +462,15 @@ def standard_abi(base_name, template, out):
457462
out.dump(line)
458463

459464
# Static internal function (add a random component to avoid conflicts)
460-
internal_name = f'ompi_abi_{template.prototype.name}'
461-
print_cdefs_for_bigcount(out)
462-
print_cdefs_for_abi(out, abi_type='standard')
463-
internal_sig = template.prototype.signature(internal_name, abi_type='ompi',
464-
enable_count=False)
465-
out.dump(consts.INLINE_ATTRS, internal_sig)
466-
template.print_body(func_name=base_name, out=out)
467-
if util.prototype_has_bigcount(template.prototype):
465+
if suppress_nbc == False:
466+
internal_name = f'ompi_abi_{template.prototype.name}'
467+
print_cdefs_for_bigcount(out)
468+
print_cdefs_for_abi(out, abi_type='standard')
469+
internal_sig = template.prototype.signature(internal_name, abi_type='ompi',
470+
enable_count=False)
471+
out.dump(consts.INLINE_ATTRS, internal_sig)
472+
template.print_body(func_name=base_name, out=out)
473+
if util.prototype_has_bigcount(template.prototype) and suppress_bc == False:
468474
internal_name = f'ompi_abi_{template.prototype.name}_c'
469475
print_cdefs_for_bigcount(out, enable_count=True)
470476
print_cdefs_for_abi(out, abi_type='standard')
@@ -502,10 +508,14 @@ def generate_function(prototype, fn_name, internal_fn, out, enable_count=False):
502508
out.dump(line)
503509
out.dump('}')
504510

505-
internal_name = f'ompi_abi_{template.prototype.name}'
506-
generate_function(template.prototype, base_name, internal_name, out)
507-
if util.prototype_has_bigcount(template.prototype):
508-
base_name_c = f'{base_name}_c'
511+
if suppress_nbc == False:
512+
internal_name = f'ompi_abi_{template.prototype.name}'
513+
generate_function(template.prototype, base_name, internal_name, out)
514+
if util.prototype_has_bigcount(template.prototype) and suppress_bc == False:
515+
if base_name[-2:] == "_c":
516+
base_name_c = f'{base_name}'
517+
else:
518+
base_name_c = f'{base_name}_c'
509519
internal_name = f'ompi_abi_{template.prototype.name}_c'
510520
generate_function(template.prototype, base_name_c, internal_name, out,
511521
enable_count=True)
@@ -529,6 +539,6 @@ def generate_source(args, out):
529539
else:
530540
base_name = util.mpi_fn_name_from_base_fn_name(template.prototype.name)
531541
if args.type == 'ompi':
532-
ompi_abi(base_name, template, out)
542+
ompi_abi(base_name, template, out, args.suppress_bc, args.suppress_nbc)
533543
else:
534-
standard_abi(base_name, template, out)
544+
standard_abi(base_name, template, out, args.suppress_bc, args.suppress_nbc)

ompi/mpi/bindings/ompi_bindings/c_type.py

Lines changed: 70 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -139,6 +139,14 @@ def parameter(self, enable_count=False, **kwargs):
139139
count_type = 'MPI_Count' if enable_count else 'int'
140140
return f'const {count_type} {self.name}[]'
141141

142+
@Type.add_type('COUNT_ARRAY_OUT')
143+
class TypeCountArrayOut(TypeCountArray):
144+
"""Array of counts out (either int or MPI_Count)."""
145+
146+
def parameter(self, enable_count=False, **kwargs):
147+
count_type = 'MPI_Count' if enable_count else 'int'
148+
return f'{count_type} {self.name}[]'
149+
142150
@Type.add_type('AINT_COUNT_ARRAY')
143151
class TypeAintCountArray(Type):
144152
"""Array of counts (either MPI_Aint or MPI_Count)."""
@@ -154,6 +162,14 @@ def parameter(self, enable_count=False, **kwargs):
154162
count_type = 'MPI_Count' if enable_count else 'MPI_Aint'
155163
return f'const {count_type} {self.name}[]'
156164

165+
@Type.add_type('AINT_COUNT_ARRAY_OUT')
166+
class TypeAintCountArrayOut(TypeAintCountArray):
167+
"""Array of counts (either MPI_Aint or MPI_Count)."""
168+
169+
def parameter(self, enable_count=False, **kwargs):
170+
count_type = 'MPI_Count' if enable_count else 'MPI_Aint'
171+
return f'{count_type} {self.name}[]'
172+
157173
@Type.add_type('ELEMENT_COUNT')
158174
class ElementCountType(Type):
159175
"""Special count type for MPI_Get_element_x"""
@@ -226,6 +242,11 @@ def type_text(self, enable_count=False):
226242
def parameter(self, enable_count=False, **kwargs):
227243
return f'const MPI_Aint {self.name}[]'
228244

245+
@Type.add_type('AINT_ARRAY_OUT')
246+
class TypeAintArrayOut(TypeAintArray):
247+
248+
def parameter(self, enable_count=False, **kwargs):
249+
return f'MPI_Aint {self.name}[]'
229250

230251
@Type.add_type('INT_OUT')
231252
class TypeIntOut(Type):
@@ -282,6 +303,15 @@ def type_text(self, enable_count=False):
282303
def parameter(self, enable_count=False, **kwargs):
283304
return f'const int {self.name}[]'
284305

306+
@Type.add_type('INT_ARRAY_OUT')
307+
class TypeIntArrayOut(TypeIntArray):
308+
309+
def type_text(self, enable_count=False):
310+
return 'int *'
311+
312+
def parameter(self, enable_count=False, **kwargs):
313+
return f'int {self.name}[]'
314+
285315
@Type.add_type('INT_AINT_OUT')
286316
class TypeIntAintOut(Type):
287317

@@ -362,6 +392,14 @@ def type_text(self, enable_count=False):
362392
def parameter(self, enable_count=False, **kwargs):
363393
return f'const {self.type_text(enable_count=enable_count)} {self.name}[]'
364394

395+
@Type.add_type('DATATYPE_ARRAY_OUT', abi_type=['ompi'])
396+
class TypeDatatypeArrayOut(Type):
397+
398+
def type_text(self, enable_count=False):
399+
return 'MPI_Datatype'
400+
401+
def parameter(self, enable_count=False, **kwargs):
402+
return f'{self.type_text(enable_count=enable_count)} {self.name}[]'
365403

366404
class StandardABIType(Type):
367405

@@ -406,9 +444,6 @@ def type_text(self, enable_count=False):
406444
def argument(self):
407445
return f'(MPI_Datatype *) {self.name}'
408446

409-
#
410-
# TODO THIS IS NOT COMPLETE
411-
#
412447
@Type.add_type('DATATYPE_ARRAY', abi_type=['standard'])
413448
class TypeDatatypeArrayStandard(StandardABIType):
414449

@@ -444,6 +479,38 @@ def parameter(self, enable_count=False, **kwargs):
444479
def argument(self):
445480
return f'(MPI_Datatype *) {self.tmpname}'
446481

482+
@Type.add_type('DATATYPE_ARRAY_OUT', abi_type=['standard'])
483+
class TypeDatatypeArrayOutStandard(StandardABIType):
484+
485+
@property
486+
def init_code(self):
487+
code = [f'int size_{self.tmpname} = {self.count_param};']
488+
code.append(f'MPI_Datatype *{self.tmpname} = (MPI_Datatype *)malloc({self.count_param} * sizeof(MPI_Datatype));')
489+
return code
490+
491+
@property
492+
def final_code(self):
493+
code = [f'for(int i=0;i<size_{self.tmpname};i++){{']
494+
code.append(f'{self.name}[i] = {ConvertOMPIToStandard.DATATYPE}({self.tmpname}[i]);')
495+
code.append(f'}}')
496+
code.append(f'free({self.tmpname});')
497+
return code
498+
499+
@property
500+
def tmpname(self):
501+
return f'{self.name}_tmp'
502+
503+
def type_text(self, enable_count=False):
504+
return self.mangle_name('MPI_Datatype')
505+
506+
def parameter(self, enable_count=False, **kwargs):
507+
return f'{self.type_text(enable_count=enable_count)} {self.name}[]'
508+
509+
@property
510+
def argument(self):
511+
return f'(MPI_Datatype *) {self.tmpname}'
512+
513+
447514
@Type.add_type('OP', abi_type=['ompi'])
448515
class TypeDatatype(Type):
449516

ompi/mpi/bindings/ompi_bindings/parser.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
# $HEADER$
99

1010
import os
11+
import sys
1112

1213
"""Source parsing code."""
1314

@@ -16,7 +17,11 @@ class Parameter:
1617
def __init__(self, text, type_constructor):
1718
"""Parse a parameter."""
1819
# parameter in the form "TYPE NAME" or "TYPE NAME:COUNT_VAR"
19-
type_name, namecount = text.split()
20+
try:
21+
type_name, namecount = text.split()
22+
except Exception as e:
23+
print(f"Error: could not split '{text}' got error {e}")
24+
sys.exit(-1)
2025
if ':' in namecount:
2126
name, count_param = namecount.split(':')
2227
else:

ompi/mpi/c/Makefile.am

Lines changed: 103 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -466,6 +466,14 @@ prototype_sources = \
466466
win_wait.c.in \
467467
wtime.c.in
468468

469+
prototype_sources_nbc = \
470+
type_get_contents.c.in_nbc \
471+
type_get_envelope.c.in_nbc
472+
473+
prototype_sources_obc = \
474+
type_get_contents_c.c.in_obc \
475+
type_get_envelope_c.c.in_obc
476+
469477
# See MPI-5 standard Chapter 20 section 4
470478
prototype_sources_not_in_abi = \
471479
comm_c2f.c.in \
@@ -498,6 +506,8 @@ prototype_sources_not_in_abi = \
498506
win_f2c.c.in
499507

500508
EXTRA_DIST = $(prototype_sources) \
509+
$(prototype_sources_nbc) \
510+
$(prototype_sources_obc) \
501511
$(prototype_sources_not_in_abi) \
502512
abi_converters.h \
503513
abi_get_info.c.in
@@ -525,8 +535,9 @@ nobase_include_HEADERS = abi.h standard_abi/mpi.h
525535
#
526536
#
527537
interface_profile_sources = $(prototype_sources:.c.in=_ompi_generated.c) \
528-
$(prototype_sources_not_in_abi:.c.in=_ompi_generated.c)
529-
538+
$(prototype_sources_not_in_abi:.c.in=_ompi_generated.c) \
539+
$(prototype_sources_nbc:.c.in_nbc=_ompi_generated.c) \
540+
$(prototype_sources_obc:.c.in_obc=_ompi_generated.c)
530541

531542
# Conditionally install the header files
532543
if WANT_INSTALL_HEADERS
@@ -539,10 +550,6 @@ endif
539550
#
540551
extra_interface_profile_sources = \
541552
pcontrol.c \
542-
type_get_contents.c \
543-
type_get_contents_c.c \
544-
type_get_envelope.c \
545-
type_get_envelope_c.c \
546553
wtick.c
547554

548555
# The following functions were removed from the MPI standard, but are
@@ -580,6 +587,51 @@ if OMPI_GENERATE_BINDINGS
580587
ompi \
581588
$<
582589

590+
# Deal with oddballs wrt big count
591+
type_get_contents_ompi_generated.c: type_get_contents.c.in_nbc
592+
$(OMPI_V_GEN) $(PYTHON) $(top_srcdir)/ompi/mpi/bindings/bindings.py \
593+
--builddir $(abs_top_builddir) \
594+
--srcdir $(abs_top_srcdir) \
595+
--output $@ \
596+
c \
597+
source \
598+
ompi \
599+
--suppress_bc \
600+
$<
601+
602+
type_get_envelope_ompi_generated.c: type_get_envelope.c.in_nbc
603+
$(OMPI_V_GEN) $(PYTHON) $(top_srcdir)/ompi/mpi/bindings/bindings.py \
604+
--builddir $(abs_top_builddir) \
605+
--srcdir $(abs_top_srcdir) \
606+
--output $@ \
607+
c \
608+
source \
609+
ompi \
610+
--suppress_bc \
611+
$<
612+
613+
type_get_contents_c_ompi_generated.c: type_get_contents_c.c.in_obc
614+
$(OMPI_V_GEN) $(PYTHON) $(top_srcdir)/ompi/mpi/bindings/bindings.py \
615+
--builddir $(abs_top_builddir) \
616+
--srcdir $(abs_top_srcdir) \
617+
--output $@ \
618+
c \
619+
source \
620+
ompi \
621+
--suppress_nbc \
622+
$<
623+
624+
type_get_envelope_c_ompi_generated.c: type_get_envelope_c.c.in_obc
625+
$(OMPI_V_GEN) $(PYTHON) $(top_srcdir)/ompi/mpi/bindings/bindings.py \
626+
--builddir $(abs_top_builddir) \
627+
--srcdir $(abs_top_srcdir) \
628+
--output $@ \
629+
c \
630+
source \
631+
ompi \
632+
--suppress_nbc \
633+
$<
634+
583635
# Non-mangled version
584636
standard_abi/mpi.h: $(top_srcdir)/docs/mpi-standard-apis.json $(top_srcdir)/ompi/mpi/bindings/c_header.py
585637
mkdir -p standard_abi
@@ -607,6 +659,51 @@ abi.h: $(top_srcdir)/docs/mpi-standard-apis.json $(top_srcdir)/ompi/mpi/bindings
607659
source \
608660
standard \
609661
$<
662+
663+
# Deal with oddballs wrt big count
664+
type_get_contents_abi_generated.c: type_get_contents.c.in_nbc
665+
$(OMPI_V_GEN) $(PYTHON) $(top_srcdir)/ompi/mpi/bindings/bindings.py \
666+
--builddir $(abs_top_builddir) \
667+
--srcdir $(abs_top_srcdir) \
668+
--output $@ \
669+
c \
670+
source \
671+
standard \
672+
--suppress_bc \
673+
$<
674+
675+
type_get_envelope_abi_generated.c: type_get_envelope.c.in_nbc
676+
$(OMPI_V_GEN) $(PYTHON) $(top_srcdir)/ompi/mpi/bindings/bindings.py \
677+
--builddir $(abs_top_builddir) \
678+
--srcdir $(abs_top_srcdir) \
679+
--output $@ \
680+
c \
681+
source \
682+
standard \
683+
--suppress_bc \
684+
$<
685+
686+
type_get_contents_c_abi_generated.c: type_get_contents_c.c.in_obc
687+
$(OMPI_V_GEN) $(PYTHON) $(top_srcdir)/ompi/mpi/bindings/bindings.py \
688+
--builddir $(abs_top_builddir) \
689+
--srcdir $(abs_top_srcdir) \
690+
--output $@ \
691+
c \
692+
source \
693+
standard \
694+
--suppress_nbc \
695+
$<
696+
697+
type_get_envelope_c_abi_generated.c: type_get_envelope_c.c.in_obc
698+
$(OMPI_V_GEN) $(PYTHON) $(top_srcdir)/ompi/mpi/bindings/bindings.py \
699+
--builddir $(abs_top_builddir) \
700+
--srcdir $(abs_top_srcdir) \
701+
--output $@ \
702+
c \
703+
source \
704+
standard \
705+
--suppress_nbc \
706+
$<
610707
endif
611708

612709
MAINTAINERCLEANFILES = *_generated.c abi_get_info.c $(nobase_include_HEADERS)

0 commit comments

Comments
 (0)