diff --git a/runtime/mount/python_lib/lplots/h5/h5ad/process_message.py b/runtime/mount/python_lib/lplots/h5/h5ad/process_message.py index 3b96253..38fedd8 100644 --- a/runtime/mount/python_lib/lplots/h5/h5ad/process_message.py +++ b/runtime/mount/python_lib/lplots/h5/h5ad/process_message.py @@ -1,4 +1,5 @@ import asyncio +import json from collections.abc import Awaitable, Callable from typing import Any @@ -19,6 +20,8 @@ alignment_is_running = False +init_data_cache: dict[tuple, dict[str, Any]] = {} + async def process_h5ad_request( msg: dict[str, Any], @@ -84,34 +87,67 @@ async def process_h5ad_request( if init_obs_key is None and init_var_key is None and len(possible_obs_keys) > 0: init_obs_key = possible_obs_keys[0] - obsm = None - index = None - recomputed_index = False - filters = None - if init_obsm_key is not None: - filters = msg.get("filters") - obsm, index, recomputed_index = get_obsm( - obj_id, adata, init_obsm_key, filters, max_visualization_cells - ) + filters = msg.get("filters") - obs = None - unique_obs = None - nrof_obs = None - counts = None - if init_obs_key is not None and init_obs_key in adata.obs: - obs, (unique_obs, counts), nrof_obs = get_obs( - obj_id, adata, init_obs_key, max_visualization_cells - ) + cache_key = ( + obj_id, + init_obsm_key, + json.dumps(filters, sort_keys=True), + init_obs_key, + init_var_key, + int(max_visualization_cells), + ) - gene_column = None - if ( - init_var_key is not None - and init_obs_key is None - and init_var_key in adata.var_names - ): - gene_column = get_obs_vector(obj_id, adata, init_var_key) + cached_value = init_data_cache.get(cache_key) + + if cached_value is None: + obsm = None + index = None + recomputed_index = False + if init_obsm_key is not None: + obsm, index, recomputed_index = get_obsm( + obj_id, adata, init_obsm_key, filters, max_visualization_cells + ) + + obs = None + unique_obs = None + nrof_obs = None + counts = None + if init_obs_key is not None and init_obs_key in adata.obs: + obs, (unique_obs, counts), nrof_obs = get_obs( + obj_id, adata, init_obs_key, max_visualization_cells + ) + + gene_column = None + if ( + init_var_key is not None + and init_obs_key is None + and init_var_key in adata.var_names + ): + gene_column = get_obs_vector(obj_id, adata, init_var_key) + + var_index, var_names = get_var_index(obj_id, adata) + + cached_value = { + "init_recomputed_index": recomputed_index, + "init_obsm_values": obsm.tolist() if obsm is not None else None, + "init_obsm_index": index.tolist() if index is not None else None, + "init_obsm_filters": filters, + "init_obs_values": obs.tolist() if obs is not None else None, + "init_obs_unique_values": ( + unique_obs.tolist() if unique_obs is not None else None + ), + "init_obs_counts": counts.tolist() if counts is not None else None, + "init_obs_nrof_values": nrof_obs, + "init_var_index": var_index.tolist(), + "init_var_names": (var_names.tolist() if var_names is not None else None), + "init_var_values": ( + gene_column.tolist() if gene_column is not None else None + ), + "init_var_key": init_var_key if init_var_key is not None else None, + } - var_index, var_names = get_var_index(obj_id, adata) + init_data_cache[cache_key] = cached_value global alignment_is_running @@ -134,32 +170,15 @@ async def process_h5ad_request( # init state with these "init_obs_key": init_obs_key, "init_obsm_key": init_obsm_key, - "init_recomputed_index": recomputed_index, - "init_obsm_values": obsm.tolist() if obsm is not None else None, - "init_obsm_index": index.tolist() if index is not None else None, - "init_obsm_filters": filters, - "init_obs_values": obs.tolist() if obs is not None else None, - "init_obs_unique_values": ( - unique_obs.tolist() if unique_obs is not None else None - ), - "init_obs_counts": counts.tolist() if counts is not None else None, - "init_obs_nrof_values": nrof_obs, - # var info - "init_var_index": var_index.tolist(), - "init_var_names": ( - var_names.tolist() if var_names is not None else None - ), - # var color by info - "init_var_values": ( - gene_column.tolist() if gene_column is not None else None - ), - "init_var_key": init_var_key if init_var_key is not None else None, + **cached_value, # alignment info "alignment_is_running": alignment_is_running, # views info "init_views": adata.uns.get("latch_views", []), # images "init_images": adata.uns.get("latch_images", {}), + "cache_key": cache_key, + "was_cached": cached_value is not None, } }, }