I have the following situation, roughly: I have a loop where I need to load some parameters and data on every iteration. In the loop, the parameters are cast to JAX arrays by a function that I have limited control over (i.e. it is wrapped up in library code), while the data is loaded as numpy arrays. There are two issues, which the code block below illustrates:
from functools import partial

import jax
import jax.numpy as jnp
import numpy as np
import pandas as pd
from jax import Array

def load_parameters(dataframe: pd.DataFrame, index: int):
    # Load parameters from dataframe as numpy
    parameters_numpy = f(dataframe, index)
    return jnp.asarray(parameters_numpy)

def preprocess_parameters(parameters: Array):
    # Preprocess parameters; low computational burden and
    # better suited for numpy
    parameters_preprocessed = g(parameters)
    return parameters_preprocessed

def shard_parameters(parameters: Array):
    # Run jax.device_put
    shard = ...
    return jax.device_put(parameters, shard)

def load_data(filename):
    data = np.asarray(…)  # load data as numpy array
    return data

@partial(jax.jit, …)  # add in_sharding? (see the sketch after the two alternatives below)
def do_compute(parameters, data, …):
    output = g(parameters, data)
    return output

dataframe = …
filenames = …
for index, filename in enumerate(filenames):
    parameters_jax = load_parameters(dataframe, index)
    parameters_jax = preprocess_parameters(parameters_jax)
    parameters_jax = shard_parameters(parameters_jax)
    data_numpy = load_data(filename)
    output = do_compute(parameters_jax, data_numpy, …)

The only thing I can think of, which feels like a hack, is to use the jax.default_device("cpu") context manager:

#
# Either load then convert back to numpy (are there array copies?)
#
dataframe = …
filenames = …
for index, filename in enumerate(filenames):
    with jax.default_device("cpu"):
        parameters = load_parameters(dataframe, index)
        parameters = jax.tree.map(lambda x: np.asarray(x), parameters)
    parameters = preprocess_parameters(parameters)
    parameters = shard_parameters(parameters)
    data_numpy = load_data(filename)
    # … or add jax.device_put
    output = do_compute(parameters, data_numpy, …)
#
# Or use CPU JAX for preprocessing
#
dataframe = …
filenames = …
for index, filename in enumerate(filenames):
    with jax.default_device("cpu"):
        parameters = load_parameters(dataframe, index)
        parameters = preprocess_parameters(parameters)
    parameters = shard_parameters(parameters)
    data_numpy = load_data(filename)
    # … or add jax.device_put (see the sketch below)
    output = do_compute(parameters, data_numpy, …)
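Not part of the original post, but to make the "add in_sharding?" and "… or add jax.device_put" comments above concrete, here is a minimal sketch of explicitly placing the inputs. The mesh, param_sharding, and data_sharding names, the single-axis mesh, and the two-argument do_compute are all assumptions for illustration:

from functools import partial

import numpy as np
import jax
from jax.sharding import Mesh, NamedSharding, PartitionSpec

# Assumed single-axis mesh over all visible devices (illustrative only).
mesh = Mesh(np.array(jax.devices()), ("devices",))
param_sharding = NamedSharding(mesh, PartitionSpec())          # parameters replicated
data_sharding = NamedSharding(mesh, PartitionSpec("devices"))  # data split along its leading axis

# One option: let jit place its inputs via in_shardings
# (shown for a two-argument do_compute; the real signature has more arguments).
@partial(jax.jit, in_shardings=(param_sharding, data_sharding))
def do_compute(parameters, data):
    return parameters @ data  # placeholder computation

# Another option: place the loaded numpy data explicitly before the call.
output = do_compute(parameters, jax.device_put(data_numpy, data_sharding))

Either way, the leading axis of the data must be divisible by the number of devices for the sharded layout shown here.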
Replies: 1 comment 5 replies
I think using jax.default_device('cpu') is a good approach for what you describe. Another option would be to use jax.experimental.io_callback within your main program to call back to the host and do any data loading and/or preprocessing with NumPy.
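Not spelled out in the reply, but a minimal sketch of the io_callback option might look like the following; the host_preprocess helper, the declared shape/dtype, and the placeholder computation are assumptions for illustration:

import numpy as np
import jax
from jax.experimental import io_callback

def host_preprocess(parameters):
    # Runs on the host and receives a NumPy array, so any NumPy code is fine here.
    return np.asarray(parameters) * 2.0  # placeholder for the real preprocessing

@jax.jit
def do_compute(parameters, data):
    # Declare the callback's output shape/dtype, then call back to the host.
    out_spec = jax.ShapeDtypeStruct(parameters.shape, parameters.dtype)
    parameters = io_callback(host_preprocess, out_spec, parameters)
    return parameters @ data  # placeholder computation

The callback is mainly useful when the NumPy step has to happen inside the jitted computation; if the preprocessing can stay outside jit, calling NumPy directly in the loop (as in the options above) avoids the callback entirely.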