Skip to content

Commit 53033ce

Browse files
raffenet and zhenggb72
authored and committed
mpl/ze: Simplify cache cleanup when stale handle is detected
Use a helper function instead of open coding the cache cleanup each time. Co-authored-by: Gengbin Zheng <[email protected]>
1 parent 60ed551 commit 53033ce

File tree

1 file changed

+95
-54
lines changed

1 file changed

+95
-54
lines changed

src/mpl/src/gpu/mpl_gpu_ze.c

Lines changed: 95 additions & 54 deletions
Original file line number · Diff line number · Diff line change
@@ -107,6 +107,7 @@ static uint32_t *subdevice_count = NULL;
107107
/* For drmfd */
108108
typedef struct _physical_device_state {
109109
int fd;
110+
int local_dev_id;
110111
int domain, bus, device, function;
111112
} physical_device_state;
112113

@@ -116,7 +117,7 @@ static physical_device_state *physical_device_states = NULL;
116117
typedef struct {
117118
const void *ptr;
118119
uint64_t mem_id;
119-
int dev_id;
120+
int shared_dev_id;
120121
int handles[2];
121122
uint32_t nhandles;
122123
UT_hash_handle hh;
@@ -189,6 +190,7 @@ static int *ipc_max_entries = NULL;
189190
typedef struct {
190191
void *ptr;
191192
uint64_t mem_id;
193+
int dev_id;
192194
UT_hash_handle hh;
193195
} MPL_ze_mem_id_entry_t;
194196

@@ -1651,6 +1653,67 @@ static int update_lru_mapped_order(void *ipc_buf, int dev_id)
16511653
goto fn_exit;
16521654
}
16531655

1656+
/* Remove cached sender-side IPC state for (ptr, mem_id) owned by local device
 * local_dev_id.  This covers two caches visible in this file:
 *   - the drmfd GEM-handle hash (gem_hash): the entry is unlinked and its GEM
 *     handles are closed against the physical device's fd;
 *   - when the specialized cache is enabled: the per-device tracked IPC-handle
 *     cache (keyed by mem_id) and the ptr -> mem_id mapping (mem_id_cache).
 * Returns MPL_SUCCESS, or MPL_ERR_GPU_INTERNAL when local_dev_id == -1 or a
 * found cache entry does not match the supplied mem_id / local_dev_id. */
static int remove_stale_sender_cache(const void *ptr, uint64_t mem_id, int local_dev_id)
1657+
{
1658+
int status, mpl_err = MPL_SUCCESS;
1659+
1660+
/* physical_device_states is only allocated in the drmfd configuration
 * (see MPL_ze_set_fds), so its presence selects the GEM-handle cleanup path */
if (physical_device_states != NULL) {
1661+
/* drmfd */
1662+
MPL_ze_gem_hash_entry_t *entry = NULL;
1663+
HASH_FIND_PTR(gem_hash, &ptr, entry);
1664+
if (entry) {
1665+
/* a found entry must match both the buffer's allocation id and the
 * local device that maps to its physical (shared) device */
if (entry->mem_id != mem_id ||
1666+
physical_device_states[entry->shared_dev_id].local_dev_id != local_dev_id) {
1667+
goto fn_fail;
1668+
}
1669+
HASH_DEL(gem_hash, entry);
1670+
1671+
/* close GEM handle */
1672+
for (int i = 0; i < entry->nhandles; i++) {
1673+
/* NOTE(review): a close_handle failure only stops this loop; the
 * error status is not propagated to the caller — confirm intended */
status =
1674+
close_handle(physical_device_states[entry->shared_dev_id].fd,
1675+
entry->handles[i]);
1676+
if (status) {
1677+
break;
1678+
}
1679+
}
1680+
1681+
MPL_free(entry);
1682+
}
1683+
}
1684+
1685+
if (likely(MPL_gpu_info.specialized_cache)) {
1686+
MPL_ze_ipc_handle_entry_t *cache_entry = NULL;
1687+
MPL_ze_mem_id_entry_t *memid_entry = NULL;
1688+
1689+
/* -1 signals device_to_dev_id failed in the caller; cannot index the
 * per-device cache array */
if (local_dev_id == -1) {
1690+
goto fn_fail;
1691+
}
1692+
1693+
/* drop the tracked IPC-handle cache entry keyed by mem_id */
HASH_FIND(hh, ipc_cache_tracked[local_dev_id], &mem_id, sizeof(uint64_t), cache_entry);
1694+
if (cache_entry) {
1695+
free_ipc_handle_cache(cache_entry);
1696+
HASH_DELETE(hh, ipc_cache_tracked[local_dev_id], cache_entry);
1697+
MPL_free(cache_entry);
1698+
}
1699+
1700+
/* drop the ptr -> mem_id mapping, but only if it matches what the
 * caller believes is cached; a mismatch indicates corrupted state */
HASH_FIND(hh, mem_id_cache, &ptr, sizeof(void *), memid_entry);
1701+
if (memid_entry) {
1702+
if (memid_entry->mem_id != mem_id || memid_entry->dev_id != local_dev_id) {
1703+
goto fn_fail;
1704+
}
1705+
HASH_DELETE(hh, mem_id_cache, memid_entry);
1706+
MPL_free(memid_entry);
1707+
}
1708+
}
1709+
1710+
fn_exit:
1711+
return mpl_err;
1712+
fn_fail:
1713+
mpl_err = MPL_ERR_GPU_INTERNAL;
1714+
goto fn_exit;
1715+
}
1716+
16541717
/* given a local device pointer, create an IPC handle */
16551718
int MPL_gpu_ipc_handle_create(const void *ptr, MPL_gpu_device_attr * ptr_attr,
16561719
MPL_gpu_ipc_mem_handle_t * ipc_handle)
@@ -1678,18 +1741,11 @@ int MPL_gpu_ipc_handle_create(const void *ptr, MPL_gpu_device_attr * ptr_attr,
16781741
mem_id = ptr_attr->prop.id;
16791742
HASH_FIND(hh, mem_id_cache, &pbase, sizeof(void *), memid_entry);
16801743

1681-
if (memid_entry && memid_entry->mem_id != mem_id) {
1682-
HASH_FIND(hh, ipc_cache_tracked[local_dev_id], &memid_entry->mem_id, sizeof(uint64_t),
1683-
cache_entry);
1684-
if (cache_entry) {
1685-
free_ipc_handle_cache(cache_entry);
1686-
HASH_DELETE(hh, ipc_cache_tracked[local_dev_id], cache_entry);
1687-
MPL_free(cache_entry);
1688-
cache_entry = NULL;
1744+
if (memid_entry && (memid_entry->mem_id != mem_id || memid_entry->dev_id != local_dev_id)) {
1745+
mpl_err = remove_stale_sender_cache(ptr, memid_entry->mem_id, memid_entry->dev_id);
1746+
if (mpl_err != MPL_SUCCESS) {
1747+
goto fn_fail;
16891748
}
1690-
1691-
HASH_DELETE(hh, mem_id_cache, memid_entry);
1692-
MPL_free(memid_entry);
16931749
memid_entry = NULL;
16941750
}
16951751

@@ -1735,6 +1791,7 @@ int MPL_gpu_ipc_handle_create(const void *ptr, MPL_gpu_device_attr * ptr_attr,
17351791
memset(memid_entry, 0, sizeof(MPL_ze_mem_id_entry_t));
17361792
memid_entry->ptr = pbase;
17371793
memid_entry->mem_id = mem_id;
1794+
memid_entry->dev_id = local_dev_id;
17381795
HASH_ADD(hh, mem_id_cache, ptr, sizeof(void *), memid_entry, MPL_MEM_OTHER);
17391796
}
17401797
}
@@ -1750,45 +1807,16 @@ int MPL_gpu_ipc_handle_create(const void *ptr, MPL_gpu_device_attr * ptr_attr,
17501807
/* ptr must be a local device pointer and base address */
17511808
int MPL_gpu_ipc_handle_destroy(const void *ptr, MPL_pointer_attr_t * gpu_attr)
17521809
{
1753-
int status, mpl_err = MPL_SUCCESS;
1754-
MPL_ze_ipc_handle_entry_t *cache_entry = NULL;
1810+
int mpl_err = MPL_SUCCESS;
17551811
int dev_id;
17561812
uint64_t mem_id;
17571813

1758-
if (physical_device_states != NULL) {
1759-
MPL_ze_gem_hash_entry_t *entry = NULL;
1760-
HASH_FIND_PTR(gem_hash, &ptr, entry);
1761-
1762-
if (entry) {
1763-
HASH_DEL(gem_hash, entry);
1764-
1765-
/* close GEM handle */
1766-
for (int i = 0; i < entry->nhandles; i++) {
1767-
status = close_handle(physical_device_states[entry->dev_id].fd, entry->handles[i]);
1768-
if (status) {
1769-
break;
1770-
}
1771-
}
1772-
1773-
MPL_free(entry);
1774-
}
1775-
}
1776-
1777-
if (likely(MPL_gpu_info.specialized_cache)) {
1778-
dev_id = device_to_dev_id(gpu_attr->device);
1779-
if (dev_id == -1) {
1780-
goto fn_fail;
1781-
}
1782-
1783-
mem_id = gpu_attr->device_attr.prop.id;
1784-
HASH_FIND(hh, ipc_cache_tracked[dev_id], &mem_id, sizeof(uint64_t), cache_entry);
1785-
1786-
if (cache_entry != NULL) {
1787-
free_ipc_handle_cache(cache_entry);
1788-
HASH_DELETE(hh, ipc_cache_tracked[dev_id], cache_entry);
1789-
MPL_free(cache_entry);
1790-
}
1814+
mem_id = gpu_attr->device_attr.prop.id;
1815+
dev_id = device_to_dev_id(gpu_attr->device);
1816+
if (dev_id == -1) {
1817+
goto fn_fail;
17911818
}
1819+
mpl_err = remove_stale_sender_cache(ptr, mem_id, dev_id);
17921820

17931821
fn_exit:
17941822
return mpl_err;
@@ -1894,7 +1922,7 @@ int MPL_gpu_ipc_handle_map(MPL_gpu_ipc_mem_handle_t * mpl_ipc_handle, int dev_id
18941922
goto fn_exit;
18951923
}
18961924

1897-
/* free a cache entry in mmap_cache_removal */
1925+
/* free a cache entry in mmap_cache_removal in finalize stage */
18981926
int MPL_ze_mmap_handle_unmap(void *ptr, int dev_id)
18991927
{
19001928
int mpl_err = MPL_SUCCESS;
@@ -1986,7 +2014,9 @@ int MPL_gpu_ipc_handle_unmap(void *ptr)
19862014
for (int i = 0; i < cache_entry->nfds; ++i) {
19872015
close(cache_entry->fds[i]);
19882016
}
1989-
2017+
if (cache_entry->ipc_buf != ptr) {
2018+
goto fn_fail;
2019+
}
19902020
HASH_DEL(ipc_cache_mapped[dev_id], cache_entry);
19912021
MPL_free(cache_entry);
19922022
cache_entry = NULL;
@@ -2912,6 +2942,7 @@ void MPL_ze_set_fds(int num_fds, int *fds, int *bdfs)
29122942
MPL_MEM_OTHER);
29132943
for (int i = 0; i < num_fds; i++) {
29142944
physical_device_states[i].fd = fds[i];
2945+
physical_device_states[i].local_dev_id = -1;
29152946
#ifdef ZE_PCI_PROPERTIES_EXT_NAME
29162947
physical_device_states[i].domain = bdfs[4 * i];
29172948
physical_device_states[i].bus = bdfs[4 * i + 1];
@@ -2928,6 +2959,7 @@ void MPL_ze_set_fds(int num_fds, int *fds, int *bdfs)
29282959
MPL_ze_device_entry_t *device_state = device_states + d;
29292960
device_state->sys_device_index = search_physical_devices(device_state->pci);
29302961
assert(device_state->sys_device_index != -1);
2962+
physical_device_states[device_state->sys_device_index].local_dev_id = d;
29312963
}
29322964
#endif
29332965
}
@@ -2992,8 +3024,16 @@ int MPL_ze_ipc_handle_create(const void *ptr, MPL_gpu_device_attr * ptr_attr, in
29923024
HASH_FIND_PTR(gem_hash, &ptr, entry);
29933025

29943026
/* invalid entry */
2995-
if (entry && entry->mem_id != mem_id) {
2996-
MPL_ze_ipc_remove_cache_handle(ptr);
3027+
if (entry &&
3028+
(entry->mem_id != mem_id ||
3029+
physical_device_states[entry->shared_dev_id].local_dev_id != local_dev_id)) {
3030+
mpl_err =
3031+
remove_stale_sender_cache(ptr, entry->mem_id,
3032+
physical_device_states[entry->
3033+
shared_dev_id].local_dev_id);
3034+
if (mpl_err != MPL_SUCCESS) {
3035+
goto fn_fail;
3036+
}
29973037
entry = NULL;
29983038
}
29993039

@@ -3018,7 +3058,7 @@ int MPL_ze_ipc_handle_create(const void *ptr, MPL_gpu_device_attr * ptr_attr, in
30183058

30193059
entry->ptr = ptr;
30203060
entry->mem_id = mem_id;
3021-
entry->dev_id = shared_dev_id;
3061+
entry->shared_dev_id = shared_dev_id;
30223062
for (int i = 0; i < nfds; i++)
30233063
entry->handles[i] = handles[i];
30243064
entry->nhandles = nfds;
@@ -3027,7 +3067,7 @@ int MPL_ze_ipc_handle_create(const void *ptr, MPL_gpu_device_attr * ptr_attr, in
30273067

30283068
for (int i = 0; i < entry->nhandles; i++)
30293069
h.fds[i] = entry->handles[i];
3030-
h.dev_id = entry->dev_id;
3070+
h.dev_id = entry->shared_dev_id;
30313071
} else {
30323072
for (int i = 0; i < nfds; i++) {
30333073
memcpy(&h.fds[i], &ze_ipc_handle[i], sizeof(int));
@@ -3263,7 +3303,7 @@ int MPL_ze_mmap_device_pointer(void *dptr, MPL_gpu_device_attr * attr,
32633303
if (likely(MPL_gpu_info.specialized_cache)) {
32643304
HASH_FIND(hh, mem_id_cache, &pbase, sizeof(void *), memid_entry);
32653305

3266-
if (memid_entry && memid_entry->mem_id != mem_id) {
3306+
if (memid_entry && (memid_entry->mem_id != mem_id || memid_entry->dev_id != local_dev_id)) {
32673307
HASH_FIND(hh, ipc_cache_tracked[local_dev_id], &memid_entry->mem_id, sizeof(uint64_t),
32683308
cache_entry);
32693309
if (cache_entry) {
@@ -3355,6 +3395,7 @@ int MPL_ze_mmap_device_pointer(void *dptr, MPL_gpu_device_attr * attr,
33553395
memset(memid_entry, 0, sizeof(MPL_ze_mem_id_entry_t));
33563396
memid_entry->ptr = pbase;
33573397
memid_entry->mem_id = mem_id;
3398+
memid_entry->dev_id = local_dev_id;
33583399
HASH_ADD(hh, mem_id_cache, ptr, sizeof(void *), memid_entry, MPL_MEM_OTHER);
33593400
}
33603401
}

0 commit comments

Comments (0)