Hi, I am trying to use the following snippet:
import torch
import torch.nn as nn

import texar.torch as tx
from texar.torch.run import *

# (1) Modeling
class BERTGPT2Model(nn.Module):
    """An encoder-decoder model with BERT as the encoder and GPT-2 as the decoder."""

    def __init__(self, vocab_size):
        super().__init__()
        # Load pre-trained tokenisers and modules by checkpoint name
        self.tokeniserBERT = tx.data.BERTTokenizer('bert-base-uncased')
        self.tokeniserGPT2 = tx.data.GPT2Tokenizer('gpt2-medium')
        self.encoder = tx.modules.BERTEncoder('bert-base-uncased')
        self.decoder = tx.modules.GPT2Decoder('gpt2-medium')  # With pre-trained weights

    def _get_decoder_output(self, batch, train=True):
        """Perform model inference, i.e., decoding."""
        # BERTEncoder embeds token ids itself, so the ids are passed directly
        enc_states, _ = self.encoder(inputs=batch['source_text_ids'],
                                     sequence_length=batch['source_length'])
        if train:  # Teacher-forcing decoding at training time
            return self.decoder(
                inputs=batch['target_text_ids'], sequence_length=batch['target_length'] - 1,
                memory=enc_states, memory_sequence_length=batch['source_length'])
        else:  # Beam search decoding at prediction time
            start_tokens = torch.full_like(batch['source_text_ids'][:, 0], BOS)  # which BOS to use?
            return self.decoder(
                beam_width=5, start_tokens=start_tokens,
                memory=enc_states, memory_sequence_length=batch['source_length'])

    def forward(self, batch):
        """Compute training loss."""
        outputs = self._get_decoder_output(batch)
        loss = tx.losses.sequence_sparse_softmax_cross_entropy(  # Sequence loss
            labels=batch['target_text_ids'][:, 1:], logits=outputs.logits,
            sequence_length=batch['target_length'] - 1)  # Automatic masking
        return {"loss": loss}

    def predict(self, batch):
        """Compute model predictions."""
        sequence, _ = self._get_decoder_output(batch, train=False)
        return {"gen_text_ids": sequence}

# (2) Data
# Create dataset splits using built-in data loaders
# (data_hparams is the per-split data config; this is part of my question below)
datasets = {split: tx.data.PairedTextData(hparams=data_hparams[split])
            for split in ["train", "valid", "test"]}
model = BERTGPT2Model(datasets["train"].target_vocab.size)

# (3) Training
# Manage the train-eval loop with the Executor API
executor = Executor(
    model=model, datasets=datasets,
    optimizer={"type": torch.optim.Adam, "kwargs": {"lr": 5e-4}},
    stop_training_on=cond.epoch(20),
    log_every=cond.iteration(100),
    validate_every=cond.epoch(1),
    train_metric=("loss", metric.RunningAverage(10, pred_name="loss")),
    valid_metric=metric.BLEU(pred_name="gen_text_ids", label_name="target_text_ids"),
    save_every=cond.validation(better=True),
    checkpoint_dir="outputs/saved_models/")
executor.train()
executor.test(datasets["test"])

In this example:
- How should I use data iterators to read the data from files? (A sketch of what I have in mind follows this list.)
- How should the data config be set up so that the source text is processed with tokeniserBERT.encode_text(src) and the target text with tokeniserGPT2.encode_text(tgt) before being passed through as a batch? (See the second sketch below.)
- Does PairedTextData have an option to pass different processors for the above use case?
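
For the first question, this is roughly what I have in mind. I am not sure the hparams keys below ("source_dataset"/"target_dataset" with "files" and "vocab_file") are the right ones for PairedTextData, and all file paths are placeholders:

data_hparams = {
    split: {
        "batch_size": 32,
        "source_dataset": {
            "files": f"data/{split}.src.txt",    # placeholder path
            "vocab_file": "data/vocab.src.txt",  # placeholder path
        },
        "target_dataset": {
            "files": f"data/{split}.tgt.txt",    # placeholder path
            "vocab_file": "data/vocab.tgt.txt",  # placeholder path
        },
    }
    for split in ["train", "valid", "test"]
}

datasets = {split: tx.data.PairedTextData(hparams=data_hparams[split])
            for split in ["train", "valid", "test"]}

# Without the Executor, I believe iterating would look like this:
iterator = tx.data.DataIterator(datasets)
iterator.switch_to_dataset("train")
for batch in iterator:
    ...  # batch['source_text_ids'], batch['source_length'], etc.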
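
For the second and third questions, if PairedTextData cannot take per-side processors, would a custom dataset along these lines be the intended approach? This is only a sketch: I am assuming tx.data.DatasetBase with process()/collate(), tx.data.SequenceDataSource, tx.data.Batch, and the tokenizers' encode_text() return values; please correct anything that is wrong.

import torch
import texar.torch as tx

class PairedTokenizerData(tx.data.DatasetBase):
    """Hypothetical dataset: BERT-tokenised sources, GPT-2-tokenised targets."""

    def __init__(self, src_file, tgt_file, hparams=None, device=None):
        # Tokenisers must exist before super().__init__(), since (I believe)
        # the default eager strategy processes all examples at construction.
        self.tokeniserBERT = tx.data.BERTTokenizer('bert-base-uncased')
        self.tokeniserGPT2 = tx.data.GPT2Tokenizer('gpt2-medium')
        with open(src_file) as f_src, open(tgt_file) as f_tgt:
            pairs = list(zip(f_src.read().splitlines(),
                             f_tgt.read().splitlines()))
        source = tx.data.SequenceDataSource(pairs)
        super().__init__(source, hparams, device)

    def process(self, raw_example):
        src_text, tgt_text = raw_example
        # Assumption: BERTTokenizer.encode_text returns (ids, segment_ids, input_mask)
        src_ids, _, src_mask = self.tokeniserBERT.encode_text(
            src_text, max_seq_length=128)
        # Assumption: GPT2Tokenizer.encode_text returns (ids, length)
        tgt_ids, tgt_len = self.tokeniserGPT2.encode_text(
            tgt_text, max_seq_length=128)
        return {
            "source_text_ids": src_ids,
            "source_length": sum(src_mask),
            "target_text_ids": tgt_ids,
            "target_length": tgt_len,
        }

    def collate(self, examples):
        return tx.data.Batch(
            len(examples),
            source_text_ids=torch.tensor([ex["source_text_ids"] for ex in examples]),
            source_length=torch.tensor([ex["source_length"] for ex in examples]),
            target_text_ids=torch.tensor([ex["target_text_ids"] for ex in examples]),
            target_length=torch.tensor([ex["target_length"] for ex in examples]))

If this is right, I would then build datasets = {split: PairedTokenizerData(...)} and feed them to the Executor as above.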
TIA