JAX out of VRAM using pmap #12155
Unanswered
Samuel-Fipps asked this question in Q&A
When running a training script by Google, pmap seems to be causing an out-of-memory error: it tries to allocate a couple of TB of memory when all I'm doing is increasing the batch size. I have never used JAX before, but I was able to pin down the location where it goes wrong.
Here is the code for the training script. Line 870 is where the pmap is created, and line 927 is where it is used and errors out. I put the error after the code.
```python
#!/usr/bin/env python
# coding=utf-8
# Copyright 2021 The HuggingFace Team All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Pretraining the library models for T5-like span-masked language modeling on a text file or a dataset.
Here is the full list of checkpoints on the hub that can be pretrained by this script:
https://huggingface.co/models?filter=t5
"""
import json
import logging
import math
import os
import sys
import time
import copy
from dataclasses import asdict, dataclass, field
from jax.experimental.maps import xmap
# You can also adapt this script on your own masked language modeling task. Pointers for this are left as comments.
from enum import Enum
from itertools import chain
from pathlib import Path
from typing import Dict, List, Optional
import numpy as np
from datasets import load_dataset
from tqdm import tqdm
import torch.cuda
import torch
import flax
import jax
import jax.numpy as jnp
import optax
from flax import jax_utils, traverse_util
from flax.jax_utils import pad_shard_unpad
from flax.training import train_state
from flax.training.common_utils import get_metrics, onehot, shard
from huggingface_hub import Repository
from transformers import (
CONFIG_MAPPING,
FLAX_MODEL_FOR_MASKED_LM_MAPPING,
AutoTokenizer,
BatchEncoding,
FlaxT5ForConditionalGeneration,
HfArgumentParser,
PreTrainedTokenizerBase,
T5Config,
is_tensorboard_available,
set_seed,
)
from transformers.models.t5.modeling_flax_t5 import shift_tokens_right
from transformers.utils import get_full_repo_name, send_example_telemetry
MODEL_CONFIG_CLASSES = list(FLAX_MODEL_FOR_MASKED_LM_MAPPING.keys())
MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES)
@dataclass
class TrainingArguments:
    output_dir: str = field(
        metadata={"help": "The output directory where the model predictions and checkpoints will be written."},
    )
    overwrite_output_dir: bool = field(
        default=False,
        metadata={
            "help": (
                "Overwrite the content of the output directory. "
                "Use this to continue training if output_dir points to a checkpoint directory."
            )
        },
    )
    do_train: bool = field(default=False, metadata={"help": "Whether to run training."})
    do_eval: bool = field(default=False, metadata={"help": "Whether to run eval on the dev set."})
    per_device_train_batch_size: int = field(
        default=8, metadata={"help": "Batch size per GPU/TPU core/CPU for training."}
    )
    per_device_eval_batch_size: int = field(
        default=8, metadata={"help": "Batch size per GPU/TPU core/CPU for evaluation."}
    )
    learning_rate: float = field(default=5e-5, metadata={"help": "The initial learning rate for AdamW."})
    weight_decay: float = field(default=0.0, metadata={"help": "Weight decay for AdamW if we apply some."})
    adam_beta1: float = field(default=0.9, metadata={"help": "Beta1 for AdamW optimizer"})
    adam_beta2: float = field(default=0.999, metadata={"help": "Beta2 for AdamW optimizer"})
    adam_epsilon: float = field(default=1e-8, metadata={"help": "Epsilon for AdamW optimizer."})
    adafactor: bool = field(default=False, metadata={"help": "Whether or not to replace AdamW by Adafactor."})
    num_train_epochs: float = field(default=3.0, metadata={"help": "Total number of training epochs to perform."})
    warmup_steps: int = field(default=0, metadata={"help": "Linear warmup over warmup_steps."})
    logging_steps: int = field(default=500, metadata={"help": "Log every X updates steps."})
    save_steps: int = field(default=500, metadata={"help": "Save checkpoint every X updates steps."})
    eval_steps: int = field(default=None, metadata={"help": "Run an evaluation every X steps."})
    seed: int = field(default=42, metadata={"help": "Random seed that will be set at the beginning of training."})
    push_to_hub: bool = field(
        default=False, metadata={"help": "Whether or not to upload the trained model to the model hub after training."}
    )
    hub_model_id: str = field(
        default=None, metadata={"help": "The name of the repository to keep in sync with the local `output_dir`."}
    )
    hub_token: str = field(default=None, metadata={"help": "The token to use to push to the Model Hub."})
@dataclass
class ModelArguments:
    """
    Arguments pertaining to which model/config/tokenizer we are going to fine-tune, or train from scratch.
    """


@dataclass
class DataTrainingArguments:
    """
    Arguments pertaining to what data we are going to input our model for training and eval.
    """
def compute_input_and_target_lengths(inputs_length, noise_density, mean_noise_span_length):
    """This function is copy of `random_spans_helper
    <https://github.com/google-research/text-to-text-transfer-transformer/blob/84f8bcc14b5f2c03de51bd3587609ba8f6bbd1cd/t5/data/preprocessors.py#L2466>`__ ."""


@flax.struct.dataclass
class FlaxDataCollatorForT5MLM:
    """
    Data collator used for T5 span-masked language modeling.
    It is made sure that after masking the inputs are of length `data_args.max_seq_length` and targets are also of
    fixed length.

    For more information on how T5 span-masked language modeling works, one can take a look at the
    `official paper <https://arxiv.org/pdf/1910.10683.pdf>`__ or the
    `official code for preprocessing
    <https://github.com/google-research/text-to-text-transfer-transformer/blob/master/t5/data/preprocessors.py>`__ .
    """


def generate_batch_splits(samples_idx: np.ndarray, batch_size: int, drop_last=True) -> np.ndarray:
    """Generate batches of data for a specified batch size from sample indices. If the dataset size is not divisible by
    the batch size and `drop_last` is `True`, the last incomplete batch is dropped. Else, it is returned."""
    num_samples = len(samples_idx)
    if drop_last:
        samples_to_remove = num_samples % batch_size
        if samples_to_remove != 0:
            samples_idx = samples_idx[:-samples_to_remove]
        sections_split = num_samples // batch_size
        samples_idx = samples_idx.reshape((sections_split, batch_size))
    else:
        sections_split = math.ceil(num_samples / batch_size)
        samples_idx = np.array_split(samples_idx, sections_split)
    return samples_idx
def write_train_metric(summary_writer, train_metrics, train_time, step):
    summary_writer.scalar("train_time", train_time, step)


def write_eval_metric(summary_writer, eval_metrics, step):
    for metric_name, value in eval_metrics.items():
        summary_writer.scalar(f"eval_{metric_name}", value, step)
def main():
    # See all possible arguments in src/transformers/training_args.py
    # or by passing the --help flag to this script.
    # We now keep distinct sets of args, for a cleaner separation of concerns.
    ...


if __name__ == "__main__":
    main()
```
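For reference, here is a minimal, self-contained sketch of how I understand the batch gets laid out for the pmapped train step. The `shard` helper is the one the script imports from `flax.training.common_utils`; the sizes and the dummy `train_step` below are made up for illustration, not taken from the script.

```python
# Minimal sketch of the pmap batch layout (sizes are made up; train_step is a stand-in).
import jax
import jax.numpy as jnp
import numpy as np
from flax.training.common_utils import shard

num_devices = jax.local_device_count()
per_device_batch_size = 8
total_batch_size = per_device_batch_size * num_devices  # what gets assembled each step

# Dummy batch of token ids: (total_batch, seq_len)
batch = {"input_ids": np.zeros((total_batch_size, 512), dtype=np.int32)}

# shard() reshapes every array to (num_devices, per_device_batch, ...), so each
# device materializes activations for per_device_batch_size sequences on its own.
sharded = shard(batch)
print(sharded["input_ids"].shape)  # (num_devices, per_device_batch_size, 512)

@jax.pmap
def train_step(input_ids):
    # stand-in for the real train_step; activation sizes scale with input_ids.shape[0]
    return jnp.sum(input_ids)

per_device_out = train_step(sharded["input_ids"])
print(per_device_out.shape)  # (num_devices,)
```

If I'm reading it right, the per-device slice (not the global batch) is what determines the size of activations like the one in the OOM below.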
Here is the error:
```
2022-08-29 21:36:02.173212: W external/org_tensorflow/tensorflow/core/common_runtime/bfc_allocator.cc:479] Allocator (GPU_1_bfc) ran out of memory trying to allocate 2.83TiB (rounded to 3109218748160)requested by op | 0/185 [00:00<?, ?it/s]
2022-08-29 21:36:02.175722: W external/org_tensorflow/tensorflow/core/common_runtime/bfc_allocator.cc:491] ***********************************************************************_____________________________
2022-08-29 21:36:02.191908: E external/org_tensorflow/tensorflow/compiler/xla/pjrt/pjrt_stream_executor_client.cc:2130] Execution of replica 1 failed: RESOURCE_EXHAUSTED: Out of memory while trying to allocate 3109218748016 bytes.
BufferAssignment OOM Debugging.
BufferAssignment stats:
parameter allocation: 14.14GiB
constant allocation: 1.2KiB
maybe_live_out allocation: 14.14GiB
preallocated temp allocation: 2.83TiB
preallocated temp fragmentation: 222.80MiB (0.01%)
total allocation: 2.84TiB
total fragmentation: 223.14MiB (0.01%)
Peak buffers:
Buffer 1:
Size: 16.00GiB
Operator: op_name="pmap(train_step)/jit(main)/jvp(FlaxT5ForConditionalGenerationModule)/encoder/block/7/layer/1/DenseReluDense/wi/dot_general[dimension_numbers=(((2,), (0,)), ((), ())) precision=None preferred_element_type=None]" source_file="/home/samuel/.local/lib/python3.8/site-packages/flax/linen/linear.py" source_line=196
XLA Label: custom-call
Shape: f32[262144,16384]
==========================
2022-08-29 21:36:02.264641: W external/org_tensorflow/tensorflow/core/common_runtime/bfc_allocator.cc:479] Allocator (GPU_0_bfc) ran out of memory trying to allocate 2.83TiB (rounded to 3109218748160)requested by op
2022-08-29 21:36:02.266129: W external/org_tensorflow/tensorflow/core/common_runtime/bfc_allocator.cc:491] ********************************************************************________________________________
2022-08-29 21:36:02.282608: E external/org_tensorflow/tensorflow/compiler/xla/pjrt/pjrt_stream_executor_client.cc:2130] Execution of replica 0 failed: RESOURCE_EXHAUSTED: Out of memory while trying to allocate 3109218748016 bytes.
BufferAssignment OOM Debugging.
BufferAssignment stats:
parameter allocation: 14.14GiB
constant allocation: 1.2KiB
maybe_live_out allocation: 14.14GiB
preallocated temp allocation: 2.83TiB
preallocated temp fragmentation: 222.80MiB (0.01%)
total allocation: 2.84TiB
total fragmentation: 223.14MiB (0.01%)
Peak buffers:
Buffer 1:
Size: 16.00GiB
Operator: op_name="pmap(train_step)/jit(main)/jvp(FlaxT5ForConditionalGenerationModule)/encoder/block/7/layer/1/DenseReluDense/wi/dot_general[dimension_numbers=(((2,), (0,)), ((), ())) precision=None preferred_element_type=None]" source_file="/home/samuel/.local/lib/python3.8/site-packages/flax/linen/linear.py" source_line=196
XLA Label: custom-call
Shape: f32[262144,16384]
==========================
Training...: 0%| | 0/185 [04:00<?, ?it/s]
Epoch ... : 0%| | 0/1 [04:00<?, ?it/s]
Traceback (most recent call last):
File "scripttraining.py", line 1030, in
if name == "main":
File "scripttraining.py", line 938, in main
File "/home/samuel/.local/lib/python3.8/site-packages/jax/_src/traceback_util.py", line 162, in reraise_with_filtered_traceback
return fun(*args, **kwargs)
File "/home/samuel/.local/lib/python3.8/site-packages/jax/src/api.py", line 2156, in cache_miss
out_tree, out_flat = f_pmapped(*args, **kwargs)
File "/home/samuel/.local/lib/python3.8/site-packages/jax/_src/api.py", line 2032, in pmap_f
out = pxla.xla_pmap(
File "/home/samuel/.local/lib/python3.8/site-packages/jax/core.py", line 2040, in bind
return map_bind(self, fun, *args, **params)
File "/home/samuel/.local/lib/python3.8/site-packages/jax/core.py", line 2072, in map_bind
outs = primitive.process(top_trace, fun, tracers, params)
File "/home/samuel/.local/lib/python3.8/site-packages/jax/core.py", line 2043, in process
return trace.process_map(self, fun, tracers, params)
File "/home/samuel/.local/lib/python3.8/site-packages/jax/core.py", line 687, in process_call
return primitive.impl(f, *tracers, **params)
File "/home/samuel/.local/lib/python3.8/site-packages/jax/interpreters/pxla.py", line 919, in xla_pmap_impl
return compiled_fun(*args)
File "/home/samuel/.local/lib/python3.8/site-packages/jax/_src/profiler.py", line 294, in wrapper
return func(*args, **kwargs)
File "/home/samuel/.local/lib/python3.8/site-packages/jax/interpreters/pxla.py", line 1742, in call
out_bufs = self.xla_executable.execute_sharded_on_local_devices(
jax._src.traceback_util.UnfilteredStackTrace: jaxlib.xla_extension.XlaRuntimeError: RESOURCE_EXHAUSTED: Out of memory while trying to allocate 3109218748016 bytes.
BufferAssignment OOM Debugging.
BufferAssignment stats:
parameter allocation: 14.14GiB
constant allocation: 1.2KiB
maybe_live_out allocation: 14.14GiB
preallocated temp allocation: 2.83TiB
preallocated temp fragmentation: 222.80MiB (0.01%)
total allocation: 2.84TiB
total fragmentation: 223.14MiB (0.01%)
Peak buffers:
Buffer 1:
Size: 16.00GiB
Operator: op_name="pmap(train_step)/jit(main)/jvp(FlaxT5ForConditionalGenerationModule)/encoder/block/7/layer/1/DenseReluDense/wi/dot_general[dimension_numbers=(((2,), (0,)), ((), ())) precision=None preferred_element_type=None]" source_file="/home/samuel/.local/lib/python3.8/site-packages/flax/linen/linear.py" source_line=196
XLA Label: custom-call
Shape: f32[262144,16384]
==========================
: while running replica 0 and partition 0 of a replicated computation (other replicas may have failed as well).
The stack trace below excludes JAX-internal frames.
The preceding is the original exception that occurred, unmodified.
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "scripttraining.py", line 1030, in
if name == "main":
File "scripttraining.py", line 938, in main
jaxlib.xla_extension.XlaRuntimeError: RESOURCE_EXHAUSTED: Out of memory while trying to allocate 3109218748016 bytes.
BufferAssignment OOM Debugging.
BufferAssignment stats:
parameter allocation: 14.14GiB
constant allocation: 1.2KiB
maybe_live_out allocation: 14.14GiB
preallocated temp allocation: 2.83TiB
preallocated temp fragmentation: 222.80MiB (0.01%)
total allocation: 2.84TiB
total fragmentation: 223.14MiB (0.01%)
Peak buffers:
Buffer 1:
Size: 16.00GiB
Operator: op_name="pmap(train_step)/jit(main)/jvp(FlaxT5ForConditionalGenerationModule)/encoder/block/7/layer/1/DenseReluDense/wi/dot_general[dimension_numbers=(((2,), (0,)), ((), ())) precision=None preferred_element_type=None]" source_file="/home/samuel/.local/lib/python3.8/site-packages/flax/linen/linear.py" source_line=196
XLA Label: custom-call
Shape: f32[262144,16384]
==========================
: while running replica 0 and partition 0 of a replicated computation (other replicas may have failed as well).
```
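In case it's useful, here is a quick sanity check I did on the numbers in the error. The peak-buffer arithmetic comes straight from the log; the sequence-length interpretation is just my guess.

```python
# Back-of-the-envelope check on the numbers reported above (the seq-length
# interpretation is an assumption, not something the log states).
peak_rows, peak_cols = 262144, 16384      # Shape: f32[262144,16384]
peak_bytes = peak_rows * peak_cols * 4    # float32 = 4 bytes per element
print(peak_bytes / 2**30)                 # 16.0 -> matches "Size: 16.00GiB"

requested = 3109218748016                 # bytes XLA tried to allocate on one GPU
print(requested / 2**40)                  # ~2.83 -> matches the 2.83TiB allocator message

# 262144 rows = per-device batch * sequence length for this dot_general, so with a
# hypothetical sequence length of 512 the per-device batch here would be 512.
print(peak_rows // 512)                   # 512
```

So the single 16 GiB peak buffer matches the reported shape exactly, and the 2.83 TiB appears to be the sum of all the temporaries XLA wants to pre-allocate for a single replica, not just that one buffer.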