
[Feature] Add unified device utilities API (keras.utils.device_utils) for backend-agnostic device and memory management #22278

@amadhan882

Description


Summary

Keras 3 provides a unified model-building interface across TensorFlow, JAX, and PyTorch.
However, there is no backend-agnostic API for device management (GPU listing, memory tracking, etc.).
Users must write backend-specific code, making it difficult to build portable profilers, debuggers, and memory-aware callbacks.

I propose adding:

keras.utils.device_utils

as a standardized module for device listing and memory reporting across all Keras backends.


Problem Statement

Currently, there is no unified method for hardware inspection:

from keras.utils import device_utils
# ImportError: cannot import name 'device_utils' from 'keras.utils'

This prevents Keras users from writing backend-neutral utilities, such as:

  • memory profilers
  • training monitors
  • distributed strategy resource diagnostics
  • automatic device selection

This is especially critical for Keras Core, where one codebase should run on TensorFlow, JAX, or PyTorch without modification.
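
For illustration, this is the kind of per-backend branching users must currently maintain by hand (the helper name get_gpu_names is hypothetical, not an existing Keras API; imports are deferred per branch so only the active backend is loaded):

```python
def get_gpu_names(backend_name):
    """Return GPU device names for one backend, with manual branching."""
    if backend_name == "tensorflow":
        import tensorflow as tf
        return [d.name for d in tf.config.list_physical_devices("GPU")]
    elif backend_name == "torch":
        import torch
        if torch.cuda.is_available():
            return [f"cuda:{i}" for i in range(torch.cuda.device_count())]
        return []
    elif backend_name == "jax":
        import jax
        try:
            return [str(d) for d in jax.devices("gpu")]
        except RuntimeError:  # raised when no GPU platform is present
            return []
    raise ValueError(f"Unsupported backend: {backend_name}")
```

Every tool author repeats some variant of this dispatch table, and each new backend or accelerator type multiplies the maintenance burden.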


Proposed Solution: Prototype Implementation

Below is a working prototype of the proposed module (keras/utils/device_utils.py).
It provides:

  • list_devices() → unified accelerator enumeration
  • get_memory_info() → unified memory reporting
  • Graceful fallback to system RAM when no GPU is found

import os

import psutil
from keras import backend

def list_devices(type="gpu"):
    """Unified list of available accelerator devices across backends."""
    # Note: `type` shadows the Python builtin; kept to mirror the proposed
    # public signature list_devices(type="gpu").
    curr_backend = backend.backend()

    if curr_backend == "tensorflow":
        import tensorflow as tf
        return tf.config.list_physical_devices(type.upper())

    elif curr_backend == "torch":
        import torch
        if type.lower() == "gpu" and torch.cuda.is_available():
            return [f"cuda:{i}" for i in range(torch.cuda.device_count())]
        return []
    
    elif curr_backend == "jax":
        import jax
        try:
            # jax.devices() with no argument returns ALL devices (CPU
            # included), so pass the platform name to filter by type.
            return jax.devices(type.lower())
        except RuntimeError:  # no devices of the requested platform
            return []

    return []

def get_memory_info(device_id=0):
    """Returns a dict with 'allocated' and 'peak' memory in bytes."""
    curr_backend = backend.backend()
    gpu_devices = list_devices(type="gpu")

    # Graceful fallback to system RAM if no accelerator is found
    if not gpu_devices:
        process = psutil.Process(os.getpid())
        mem_info = process.memory_info()
        return {
            "allocated": mem_info.rss,  # resident set size
            "peak": mem_info.vms,       # virtual size, used as a peak proxy
            "type": "system_ram"
        }

    if curr_backend == "tensorflow":
        import tensorflow as tf
        _ = tf.zeros((1,))  # force device initialization so stats are populated
        try:
            info = tf.config.experimental.get_memory_info(f"GPU:{device_id}")
            return {"allocated": info["current"], "peak": info["peak"], "type": "gpu"}
        except ValueError:  # raised for an unknown device name
            return {"allocated": 0, "peak": 0, "type": "gpu"}
    
    elif curr_backend == "torch":
        import torch
        return {
            "allocated": torch.cuda.memory_allocated(device_id),
            "peak": torch.cuda.max_memory_allocated(device_id),
            "type": "gpu"
        }
    
    elif curr_backend == "jax":
        import jax
        try:
            stats = jax.local_devices()[device_id].memory_stats()
            return {
                "allocated": stats["bytes_in_use"],
                "peak": stats["peak_bytes_in_use"],
                "type": "gpu"
            }
        except (IndexError, KeyError, RuntimeError, TypeError):
            # memory_stats() can be None or lack these keys on some platforms
            return {"allocated": 0, "peak": 0, "type": "gpu"}

    raise NotImplementedError(f"Memory stats not supported for backend {curr_backend}")


if __name__ == "__main__":
    print("Backend:", backend.backend())
    print("Devices:", list_devices())
    print("Memory Info:", get_memory_info())

Verified Proof of Concept

On a CPU-only environment (TensorFlow backend):

Backend: tensorflow
Devices: []
Memory Info: {'allocated': 852344832, 'peak': 3546464256, 'type': 'system_ram'}

This demonstrates the graceful fallback to system RAM on a CPU-only TensorFlow setup; the JAX and PyTorch branches expose the same API and fall back identically when no accelerator is present.


Technical Benefits

Backend parity

A unified API for device enumeration and memory reporting across TensorFlow, JAX, and PyTorch.

Developer experience

Eliminates backend-specific boilerplate:

  • No need to call torch.cuda.*, tf.config.*, or jax.devices() manually
  • Enables backend-agnostic callbacks and utilities
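
As one example of the callbacks this would enable, here is a hedged sketch (MemoryLoggingCallback is hypothetical; in a real integration it would subclass keras.callbacks.Callback and default memory_fn to the proposed get_memory_info; the reporter is injected here so the sketch runs without Keras or an accelerator):

```python
class MemoryLoggingCallback:
    """Records peak memory after every epoch via an injected reporter."""

    def __init__(self, memory_fn):
        # memory_fn returns a dict with "allocated", "peak", "type" keys,
        # matching the proposed get_memory_info() contract.
        self.memory_fn = memory_fn
        self.history = []

    def on_epoch_end(self, epoch, logs=None):
        info = self.memory_fn()
        self.history.append((epoch, info["peak"]))
```

Because the callback only depends on the unified dict contract, the same class works unchanged on TensorFlow, JAX, PyTorch, or the system-RAM fallback.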

Ecosystem growth

Allows external tool developers to integrate with Keras without maintaining three separate backends.

Future-safe

Standardizes memory reporting for upcoming accelerators (NPUs, VPUs, etc.).


Request

Please consider adding:

keras/utils/device_utils.py

with the unified functions:

  • list_devices(type="gpu")
  • get_memory_info(device_id=0)

This would significantly improve backend-agnostic tooling in Keras 3.
