@@ -489,14 +489,18 @@ class DirectSessionFactory : public SessionFactory {
489489
490490 ResourceMgr* gpu_shared_rmgr = nullptr ;
491491#if GOOGLE_CUDA
492+ bool use_per_session_host_allocator = false ;
493+ TF_CHECK_OK (tensorflow::ReadBoolFromEnvVar (" PER_SESSION_HOSTALLOC" ,
494+ /* default_val=*/ false ,
495+ &use_per_session_host_allocator));
492496 if (use_multi_stream) {
493497 // Create shared resource for gpu devices
494498 gpu_shared_rmgr = new ResourceMgr (" localhost" );
495499 std::string gpu_dev_prefix (" /job:localhost/replica:0/task:0/device:GPU:" );
496500 for (int i = 0 ; i < session_num; ++i) {
497501 dev_rmgr_map.device_rmgr_map [gpu_dev_prefix+std::to_string (base_index+i)] =
498502 gpu_shared_rmgr;
499- if (i > 0 ) {
503+ if (use_per_session_host_allocator && i > 0 ) {
500504 dev_rmgr_map.device_rmgr_map [dev_prefix+" /device:CPU:" +std::to_string (i)] = shared_rmgr;
501505 dev_rmgr_map.device_rmgr_map [dev_prefix+" /device:cpu:" +std::to_string (i)] = shared_rmgr;
502506 dev_rmgr_map.device_rmgr_map [" /device:CPU:" +std::to_string (i)] = shared_rmgr;
@@ -571,8 +575,13 @@ class DirectSessionFactory : public SessionFactory {
571575 follower_options.config .add_per_session_devices (
572576 " /job:localhost/replica:0/task:0/device:GPU:" +
573577 std::to_string (base_index+i));
574- follower_options.config .add_per_session_devices (
575- " /job:localhost/replica:0/task:0/device:CPU:" +std::to_string (i));
578+ if (use_per_session_host_allocator) {
579+ follower_options.config .add_per_session_devices (
580+ " /job:localhost/replica:0/task:0/device:CPU:" +std::to_string (i));
581+ } else {
582+ follower_options.config .add_per_session_devices (
583+ " /job:localhost/replica:0/task:0/device:CPU:0" );
584+ }
576585 }
577586#endif // GOOGLE_CUDA
578587
0 commit comments