@@ -96,8 +96,7 @@ CachingDeviceAllocator::Block::Block(
       m_allocated(0),
       m_prev(nullptr),
       m_next(nullptr),
-      m_event_cnt(0),
-      m_owner_private_pool(nullptr) {
+      m_event_cnt(0) {
   auto device_cnt = at::xpu::device_count();
   std::vector<DeviceStats> dev_stats;
 }
@@ -117,8 +116,7 @@ CachingDeviceAllocator::Block::Block(
       m_allocated(false),
       m_prev(nullptr),
       m_next(nullptr),
-      m_event_cnt(0),
-      m_owner_private_pool(nullptr) {}
+      m_event_cnt(0) {}
 
 bool CachingDeviceAllocator::Block::is_split() const {
   return (m_prev != nullptr) || (m_next != nullptr);
@@ -202,39 +200,13 @@ void CachingDeviceAllocator::malloc(
   }
 
   BlockPool* pool = nullptr;
-  PrivatePool* private_pool = nullptr;
   PoolType pool_type = PoolType::UNDEF;
-  Block* block = nullptr;
-
-  if (recordings_underway.size()) {
-    // Graph path: find the block pool belonging to the PrivatePool
-    // that is recording a graph on the current queue.
-    for (auto& entry : recordings_underway) {
-      if (entry.second(queue)) {
-        auto it1 = graph_pools.find(entry.first);
-        TORCH_INTERNAL_ASSERT(it1 != graph_pools.end());
-        if (size <= kSmallSize) {
-          pool_type = PoolType::SMALL_POOL;
-          pool = &it1->second->small_blocks;
-        } else {
-          pool_type = PoolType::LARGE_POOL;
-          pool = &it1->second->large_blocks;
-        }
-        private_pool = it1->second.get();
-      }
-    }
-  }
-  // Fallback check; this cannot be an 'else': no filter may have matched.
-  if (pool == nullptr) {
-    // Normal path: search the DeviceCachingAllocator's own pools
-    // for a suitable block.
-    if (size <= kSmallSize) {
-      pool_type = PoolType::SMALL_POOL;
-      pool = &small_blocks;
-    } else {
-      pool_type = PoolType::LARGE_POOL;
-      pool = &large_blocks;
-    }
+  if (size <= kSmallSize) {
+    pool_type = PoolType::SMALL_POOL;
+    pool = &small_blocks;
+  } else {
+    pool_type = PoolType::LARGE_POOL;
+    pool = &large_blocks;
   }
 
   Block search_key(curDevID, *queue, size);
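For readers following the removal: the deleted branch gave malloc a two-stage pool lookup. While any graph recording is underway, an allocation whose queue matches a recording's filter is routed into that recording's PrivatePool; otherwise it falls through to the shared small/large pools, which is all that remains after this change. Below is a minimal standalone sketch of that routing pattern; the stripped-down types, the pick_pool helper, and the 1 MiB threshold are illustrative assumptions, not the allocator's real interface.

```cpp
#include <cstddef>
#include <functional>
#include <utility>
#include <vector>

// Illustrative stand-ins for the allocator's types.
struct Queue { int id; };
struct BlockPool { /* set of cached free blocks, elided */ };
struct PrivatePool { BlockPool small_blocks, large_blocks; };

constexpr std::size_t kSmallSize = 1 << 20; // assumed threshold, 1 MiB

using Filter = std::function<bool(Queue*)>;
using Recording = std::pair<PrivatePool*, Filter>;

// Two-stage lookup: the private pool of a matching recording first,
// the shared pools as fallback when no filter claims the queue.
BlockPool* pick_pool(
    std::vector<Recording>& recordings_underway,
    BlockPool& small_blocks,
    BlockPool& large_blocks,
    Queue* queue,
    std::size_t size) {
  for (auto& entry : recordings_underway) {
    if (entry.second(queue)) {
      PrivatePool* p = entry.first;
      return size <= kSmallSize ? &p->small_blocks : &p->large_blocks;
    }
  }
  // The fallback cannot hinge on recordings_underway being empty:
  // recordings may exist whose filters all reject this queue.
  return size <= kSmallSize ? &small_blocks : &large_blocks;
}

int main() {
  BlockPool small, large;
  PrivatePool priv;
  std::vector<Recording> recs = {
      {&priv, [](Queue* q) { return q->id == 7; }}};
  Queue q7{7}, q0{0};
  // Queue 7 is being recorded: routed to the private pool.
  bool a = pick_pool(recs, small, large, &q7, 512) == &priv.small_blocks;
  // Queue 0 is not: falls back to the shared pools.
  bool b = pick_pool(recs, small, large, &q0, 512) == &small;
  return (a && b) ? 0 : 1;
}
```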
@@ -253,7 +225,7 @@ void CachingDeviceAllocator::malloc(
   stat_types[static_cast<size_t>(StatType::AGGREGATE)] = true;
   stat_types[static_cast<size_t>(get_stat_type_for_pool(pool_type))] = true;
   DeviceStats& stats = get_stats_for_device(curDevID);
-  block = find_free_block();
+  Block* block = find_free_block();
 
   if (block == nullptr) {
     void* buffer;
@@ -297,9 +269,6 @@ void CachingDeviceAllocator::malloc(
   Block* remaining = nullptr;
   AT_ASSERT(block);
 
-  // Record the block's owner pool for lazy release.
-  block->m_owner_private_pool = private_pool;
-
   const bool already_split = block->is_split();
   if (block->should_split(size)) {
     remaining = block;
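Context for the surviving lines: when find_free_block returns a block larger than the request, malloc splits it, handing out the front and keeping the tail cached; is_split and should_split above exist to support this. A simplified sketch of such a split on a doubly linked block list follows. The m_* fields mirror the names in the diff, but the split_block helper and the layout are assumptions, not the allocator's actual code.

```cpp
#include <cassert>
#include <cstddef>

// Simplified block record mirroring the m_* fields seen in the diff.
struct Block {
  void* m_buffer = nullptr;
  std::size_t m_size = 0;
  Block* m_prev = nullptr;
  Block* m_next = nullptr;
};

// Hypothetical helper: carve `size` bytes off the front of `block`;
// the remainder becomes a new neighbour in the same segment.
Block* split_block(Block* block, std::size_t size) {
  assert(size < block->m_size);
  Block* remaining = new Block(*block); // tail inherits the old links
  remaining->m_buffer = static_cast<char*>(block->m_buffer) + size;
  remaining->m_size = block->m_size - size;
  remaining->m_prev = block;
  if (remaining->m_next != nullptr) {
    remaining->m_next->m_prev = remaining;
  }
  block->m_next = remaining;
  block->m_size = size;
  return remaining; // caller re-inserts this into the free pool
}

int main() {
  char segment[4096];
  Block* b = new Block{segment, sizeof(segment)};
  Block* rest = split_block(b, 1024);
  assert(b->m_size == 1024 && rest->m_size == 3072);
  assert(b->m_next == rest && rest->m_prev == b);
  delete rest;
  delete b;
}
```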
@@ -448,18 +417,10 @@ void CachingDeviceAllocator::free_block(Block* block) {
   size_t original_block_size = block->m_size;
 
   BlockPool* pool = nullptr;
-  if (block->m_owner_private_pool == nullptr) {
-    if (block->m_pool_type == PoolType::LARGE_POOL) {
-      pool = &large_blocks;
-    } else if (block->m_pool_type == PoolType::SMALL_POOL) {
-      pool = &small_blocks;
-    }
-  } else {
-    if (block->m_pool_type == PoolType::LARGE_POOL) {
-      pool = &block->m_owner_private_pool->large_blocks;
-    } else if (block->m_pool_type == PoolType::SMALL_POOL) {
-      pool = &block->m_owner_private_pool->small_blocks;
-    }
+  if (block->m_pool_type == PoolType::LARGE_POOL) {
+    pool = &large_blocks;
+  } else if (block->m_pool_type == PoolType::SMALL_POOL) {
+    pool = &small_blocks;
   }
 
   int64_t net_change_inactive_split_blocks = 0;
@@ -673,17 +634,6 @@ void CachingDeviceAllocator::free_cached_blocks(DeviceId di) {
   free_blocks(large_blocks, begin, end);
   find_cached_blocks_bound(di, small_blocks, begin, end);
   free_blocks(small_blocks, begin, end);
-
-  // Release graph private pools.
-  for (auto it = graph_pools_freeable.begin();
-       it != graph_pools_freeable.end();) {
-    TORCH_INTERNAL_ASSERT(it->second->use_count == 0);
-    free_blocks(it->second->small_blocks, begin, end);
-    free_blocks(it->second->large_blocks, begin, end);
-    auto erase_count = graph_pools.erase(it->first);
-    TORCH_INTERNAL_ASSERT(erase_count == 1);
-    it = graph_pools_freeable.erase(it);
-  }
 }
 
 void CachingDeviceAllocator::synchronize_and_free_events(
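The loop deleted above walks graph_pools_freeable and erases entries as it goes, relying on the standard idiom that std::map::erase returns the iterator following the erased element. A self-contained illustration of the idiom, with a plain int-keyed map standing in for the pool maps:

```cpp
#include <cassert>
#include <map>

struct Pool { int use_count = 0; };

int main() {
  std::map<int, Pool> freeable = {{1, {0}}, {2, {0}}, {3, {2}}};

  // Erase while iterating: advance via the iterator that erase()
  // returns, and only increment manually when nothing was erased.
  for (auto it = freeable.begin(); it != freeable.end();) {
    if (it->second.use_count == 0) {
      it = freeable.erase(it); // never reuse the invalidated iterator
    } else {
      ++it;
    }
  }

  assert(freeable.size() == 1 && freeable.count(3) == 1);
}
```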
@@ -833,62 +783,5 @@ void CachingDeviceAllocator::dumpMemoryStatus(DeviceId deviceIndex) {
           .current));
 }
 
-// Called by XPUGraph::begin_recording
-void CachingDeviceAllocator::beginAllocateToPool(
-    DeviceId deviceIndex,
-    MempoolId_t mempoolId,
-    std::function<bool(sycl::queue*)> filter) {
-  std::lock_guard<std::recursive_mutex> lock(mutex);
-  auto search_key = std::make_pair(deviceIndex, mempoolId);
-  auto it = graph_pools.find(search_key);
-  if (it == graph_pools.end()) {
-    graph_pools.emplace(search_key, std::make_unique<PrivatePool>());
-  } else {
-    TORCH_INTERNAL_ASSERT(it->second->use_count > 0);
-    it->second->use_count += 1;
-  }
-  for (auto it2 = recordings_underway.begin(); it2 != recordings_underway.end();
-       ++it2) {
-    TORCH_CHECK(
-        it2->first != search_key,
-        "beginAllocateToPool: already recording to mempool_id");
-  }
-  recordings_underway.emplace_back(search_key, std::move(filter));
-}
-
-// Called by XPUGraph::end_recording
-void CachingDeviceAllocator::endAllocateToPool(
-    DeviceId deviceIndex,
-    MempoolId_t mempoolId) {
-  std::lock_guard<std::recursive_mutex> lock(mutex);
-  auto search_key = std::make_pair(deviceIndex, mempoolId);
-  for (auto it = recordings_underway.begin(); it != recordings_underway.end();
-       ++it) {
-    if (it->first == search_key) {
-      recordings_underway.erase(it);
-      return;
-    }
-  }
-  TORCH_CHECK(
-      false, "endAllocateToPool: not currently recording to mempool_id");
-}
-
-// Called by XPUGraph::reset
-void CachingDeviceAllocator::releasePool(
-    DeviceId deviceIndex,
-    MempoolId_t mempoolId) {
-  std::lock_guard<std::recursive_mutex> lock(mutex);
-  auto search_key = std::make_pair(deviceIndex, mempoolId);
-  auto it = graph_pools.find(search_key);
-  TORCH_INTERNAL_ASSERT(it != graph_pools.end());
-  auto uc = --(it->second->use_count);
-  TORCH_INTERNAL_ASSERT(uc >= 0);
-  if (uc == 0) {
-    bool inserted =
-        graph_pools_freeable.insert({search_key, it->second.get()}).second;
-    TORCH_INTERNAL_ASSERT(inserted);
-  }
-}
-
 } // namespace dpcpp
 } // namespace torch_ipex::xpu