@@ -325,13 +325,68 @@ static inline int end_atomicity(ompi_osc_ucx_module_t *module, ucp_ep_h ep, int
325325 return OMPI_SUCCESS ;
326326}
327327
328+ static inline int get_dynamic_win_info (uint64_t remote_addr , ompi_osc_ucx_module_t * module ,
329+ ucp_ep_h ep , int target ) {
330+ ucp_rkey_h state_rkey = (module -> state_info_array )[target ].rkey ;
331+ uint64_t remote_state_addr = (module -> state_info_array )[target ].addr + OSC_UCX_STATE_DYNAMIC_WIN_CNT_OFFSET ;
332+ size_t len = sizeof (uint64_t ) + sizeof (ompi_osc_dynamic_win_info_t ) * OMPI_OSC_UCX_ATTACH_MAX ;
333+ char * temp_buf = malloc (len );
334+ ompi_osc_dynamic_win_info_t * temp_dynamic_wins ;
335+ int win_count , contain , insert = -1 ;
336+ ucs_status_t status ;
337+
338+ if ((module -> win_info_array [target ]).rkey_init == true) {
339+ ucp_rkey_destroy ((module -> win_info_array [target ]).rkey );
340+ (module -> win_info_array [target ]).rkey_init == false;
341+ }
342+
343+ status = ucp_get_nbi (ep , (void * )temp_buf , len , remote_state_addr , state_rkey );
344+ if (status != UCS_OK && status != UCS_INPROGRESS ) {
345+ opal_output_verbose (1 , ompi_osc_base_framework .framework_output ,
346+ "%s:%d: ucp_get_nbi failed: %d\n" ,
347+ __FILE__ , __LINE__ , status );
348+ return OMPI_ERROR ;
349+ }
350+
351+ status = ucp_ep_flush (ep );
352+ if (status != UCS_OK ) {
353+ opal_output_verbose (1 , ompi_osc_base_framework .framework_output ,
354+ "%s:%d: ucp_ep_flush failed: %d\n" ,
355+ __FILE__ , __LINE__ , status );
356+ return OMPI_ERROR ;
357+ }
358+
359+ memcpy (& win_count , temp_buf , sizeof (uint64_t ));
360+ assert (win_count > 0 && win_count <= OMPI_OSC_UCX_ATTACH_MAX );
361+
362+ temp_dynamic_wins = (ompi_osc_dynamic_win_info_t * )(temp_buf + sizeof (uint64_t ));
363+ contain = ompi_osc_find_attached_region_position (temp_dynamic_wins , 0 , win_count ,
364+ remote_addr , 1 , & insert );
365+ assert (contain >= 0 && contain < win_count );
366+
367+ status = ucp_ep_rkey_unpack (ep , temp_dynamic_wins [contain ].rkey_buffer ,
368+ & ((module -> win_info_array [target ]).rkey ));
369+ if (status != UCS_OK ) {
370+ opal_output_verbose (1 , ompi_osc_base_framework .framework_output ,
371+ "%s:%d: ucp_ep_rkey_unpack failed: %d\n" ,
372+ __FILE__ , __LINE__ , status );
373+ return OMPI_ERROR ;
374+ }
375+
376+ (module -> win_info_array [target ]).rkey_init = true;
377+
378+ free (temp_buf );
379+
380+ return status ;
381+ }
382+
328383int ompi_osc_ucx_put (const void * origin_addr , int origin_count , struct ompi_datatype_t * origin_dt ,
329384 int target , ptrdiff_t target_disp , int target_count ,
330385 struct ompi_datatype_t * target_dt , struct ompi_win_t * win ) {
331386 ompi_osc_ucx_module_t * module = (ompi_osc_ucx_module_t * ) win -> w_osc_module ;
332387 ucp_ep_h ep = OSC_UCX_GET_EP (module -> comm , target );
333388 uint64_t remote_addr = (module -> win_info_array [target ]).addr + target_disp * OSC_UCX_GET_DISP (module , target );
334- ucp_rkey_h rkey = ( module -> win_info_array [ target ]). rkey ;
389+ ucp_rkey_h rkey ;
335390 bool is_origin_contig = false, is_target_contig = false;
336391 ptrdiff_t origin_lb , origin_extent , target_lb , target_extent ;
337392 ucs_status_t status ;
@@ -342,6 +397,15 @@ int ompi_osc_ucx_put(const void *origin_addr, int origin_count, struct ompi_data
342397 return ret ;
343398 }
344399
400+ if (module -> flavor == MPI_WIN_FLAVOR_DYNAMIC ) {
401+ status = get_dynamic_win_info (remote_addr , module , ep , target );
402+ if (status != UCS_OK ) {
403+ return OMPI_ERROR ;
404+ }
405+ }
406+
407+ rkey = (module -> win_info_array [target ]).rkey ;
408+
345409 ompi_datatype_get_true_extent (origin_dt , & origin_lb , & origin_extent );
346410 ompi_datatype_get_true_extent (target_dt , & target_lb , & target_extent );
347411
@@ -378,7 +442,7 @@ int ompi_osc_ucx_get(void *origin_addr, int origin_count,
378442 ompi_osc_ucx_module_t * module = (ompi_osc_ucx_module_t * ) win -> w_osc_module ;
379443 ucp_ep_h ep = OSC_UCX_GET_EP (module -> comm , target );
380444 uint64_t remote_addr = (module -> win_info_array [target ]).addr + target_disp * OSC_UCX_GET_DISP (module , target );
381- ucp_rkey_h rkey = ( module -> win_info_array [ target ]). rkey ;
445+ ucp_rkey_h rkey ;
382446 ptrdiff_t origin_lb , origin_extent , target_lb , target_extent ;
383447 bool is_origin_contig = false, is_target_contig = false;
384448 ucs_status_t status ;
@@ -389,6 +453,15 @@ int ompi_osc_ucx_get(void *origin_addr, int origin_count,
389453 return ret ;
390454 }
391455
456+ if (module -> flavor == MPI_WIN_FLAVOR_DYNAMIC ) {
457+ status = get_dynamic_win_info (remote_addr , module , ep , target );
458+ if (status != UCS_OK ) {
459+ return OMPI_ERROR ;
460+ }
461+ }
462+
463+ rkey = (module -> win_info_array [target ]).rkey ;
464+
392465 ompi_datatype_get_true_extent (origin_dt , & origin_lb , & origin_extent );
393466 ompi_datatype_get_true_extent (target_dt , & target_lb , & target_extent );
394467
@@ -557,10 +630,11 @@ int ompi_osc_ucx_compare_and_swap(const void *origin_addr, const void *compare_a
557630 ompi_osc_ucx_module_t * module = (ompi_osc_ucx_module_t * )win -> w_osc_module ;
558631 ucp_ep_h ep = OSC_UCX_GET_EP (module -> comm , target );
559632 uint64_t remote_addr = (module -> win_info_array [target ]).addr + target_disp * OSC_UCX_GET_DISP (module , target );
560- ucp_rkey_h rkey = ( module -> win_info_array [ target ]). rkey ;
633+ ucp_rkey_h rkey ;
561634 size_t dt_bytes ;
562635 ompi_osc_ucx_internal_request_t * req = NULL ;
563636 int ret = OMPI_SUCCESS ;
637+ ucs_status_t status ;
564638
565639 ret = check_sync_state (module , target , false);
566640 if (ret != OMPI_SUCCESS ) {
@@ -572,6 +646,15 @@ int ompi_osc_ucx_compare_and_swap(const void *origin_addr, const void *compare_a
572646 return ret ;
573647 }
574648
649+ if (module -> flavor == MPI_WIN_FLAVOR_DYNAMIC ) {
650+ status = get_dynamic_win_info (remote_addr , module , ep , target );
651+ if (status != UCS_OK ) {
652+ return OMPI_ERROR ;
653+ }
654+ }
655+
656+ rkey = (module -> win_info_array [target ]).rkey ;
657+
575658 ompi_datatype_type_size (dt , & dt_bytes );
576659 memcpy (result_addr , origin_addr , dt_bytes );
577660 req = ucp_atomic_fetch_nb (ep , UCP_ATOMIC_FETCH_OP_CSWAP , * (uint64_t * )compare_addr ,
@@ -604,17 +687,27 @@ int ompi_osc_ucx_fetch_and_op(const void *origin_addr, void *result_addr,
604687 op == & ompi_mpi_op_sum .op ) {
605688 ucp_ep_h ep = OSC_UCX_GET_EP (module -> comm , target );
606689 uint64_t remote_addr = (module -> win_info_array [target ]).addr + target_disp * OSC_UCX_GET_DISP (module , target );
607- ucp_rkey_h rkey = ( module -> win_info_array [ target ]). rkey ;
690+ ucp_rkey_h rkey ;
608691 uint64_t value = * (uint64_t * )origin_addr ;
609692 ucp_atomic_fetch_op_t opcode ;
610693 size_t dt_bytes ;
611694 ompi_osc_ucx_internal_request_t * req = NULL ;
695+ ucs_status_t status ;
612696
613697 ret = start_atomicity (module , ep , target );
614698 if (ret != OMPI_SUCCESS ) {
615699 return ret ;
616700 }
617701
702+ if (module -> flavor == MPI_WIN_FLAVOR_DYNAMIC ) {
703+ status = get_dynamic_win_info (remote_addr , module , ep , target );
704+ if (status != UCS_OK ) {
705+ return OMPI_ERROR ;
706+ }
707+ }
708+
709+ rkey = (module -> win_info_array [target ]).rkey ;
710+
618711 ompi_datatype_type_size (dt , & dt_bytes );
619712
620713 if (op == & ompi_mpi_op_replace .op ) {
@@ -789,7 +882,7 @@ int ompi_osc_ucx_rput(const void *origin_addr, int origin_count,
789882 ompi_osc_ucx_module_t * module = (ompi_osc_ucx_module_t * ) win -> w_osc_module ;
790883 ucp_ep_h ep = OSC_UCX_GET_EP (module -> comm , target );
791884 uint64_t remote_addr = (module -> state_info_array [target ]).addr + OSC_UCX_STATE_REQ_FLAG_OFFSET ;
792- ucp_rkey_h rkey = ( module -> state_info_array [ target ]). rkey ;
885+ ucp_rkey_h rkey ;
793886 ompi_osc_ucx_request_t * ucx_req = NULL ;
794887 ompi_osc_ucx_internal_request_t * internal_req = NULL ;
795888 ucs_status_t status ;
@@ -800,6 +893,15 @@ int ompi_osc_ucx_rput(const void *origin_addr, int origin_count,
800893 return ret ;
801894 }
802895
896+ if (module -> flavor == MPI_WIN_FLAVOR_DYNAMIC ) {
897+ status = get_dynamic_win_info (remote_addr , module , ep , target );
898+ if (status != UCS_OK ) {
899+ return OMPI_ERROR ;
900+ }
901+ }
902+
903+ rkey = (module -> win_info_array [target ]).rkey ;
904+
803905 OMPI_OSC_UCX_REQUEST_ALLOC (win , ucx_req );
804906 if (NULL == ucx_req ) {
805907 return OMPI_ERR_TEMP_OUT_OF_RESOURCE ;
@@ -843,7 +945,7 @@ int ompi_osc_ucx_rget(void *origin_addr, int origin_count,
843945 ompi_osc_ucx_module_t * module = (ompi_osc_ucx_module_t * ) win -> w_osc_module ;
844946 ucp_ep_h ep = OSC_UCX_GET_EP (module -> comm , target );
845947 uint64_t remote_addr = (module -> state_info_array [target ]).addr + OSC_UCX_STATE_REQ_FLAG_OFFSET ;
846- ucp_rkey_h rkey = ( module -> state_info_array [ target ]). rkey ;
948+ ucp_rkey_h rkey ;
847949 ompi_osc_ucx_request_t * ucx_req = NULL ;
848950 ompi_osc_ucx_internal_request_t * internal_req = NULL ;
849951 ucs_status_t status ;
@@ -854,6 +956,15 @@ int ompi_osc_ucx_rget(void *origin_addr, int origin_count,
854956 return ret ;
855957 }
856958
959+ if (module -> flavor == MPI_WIN_FLAVOR_DYNAMIC ) {
960+ status = get_dynamic_win_info (remote_addr , module , ep , target );
961+ if (status != UCS_OK ) {
962+ return OMPI_ERROR ;
963+ }
964+ }
965+
966+ rkey = (module -> win_info_array [target ]).rkey ;
967+
857968 OMPI_OSC_UCX_REQUEST_ALLOC (win , ucx_req );
858969 if (NULL == ucx_req ) {
859970 return OMPI_ERR_TEMP_OUT_OF_RESOURCE ;
0 commit comments