feat: Add support for jinja based template rendering of the dataset#438
Conversation
Signed-off-by: Abhishek <maurya.abhishek@ibm.com>
|
Thanks for making a pull request! 😃 |
tuning/data/data_handlers.py
Outdated
| except Exception as e: | ||
| raise KeyError(f"Dataset does not contain field in template. {e}") from e | ||
|
|
||
| rendered_text += tokenizer.eos_token |
There was a problem hiding this comment.
@ashokponkumar Wanted to just confirm the removal of eos_token from dataset samples in this handler. In other handlers we add eos_token and don't expect users to add it. Hence in this handler where user passes Jinja template are we expecting user to pass eos_token too? I guess in case of non-pretokenized dataset not using eos_token when using DataCollatorForCompletionOnlyLM might affect F1 score on tuned models ?
2- @dushyantbehl Can I ask how Jinja templating could be used with pre-tokenized dataset (Having input_ids and labels as columns) ?
There was a problem hiding this comment.
- I think we need a proper documentation for now and a patch where we let users choose if they want an
eos_tokenwith the data handlers or not via one argument e.g. add akwargto the data handlers likeadd_eos_tokenthis way we can let them choose what they want inside a data config.
for a data config we should not assume things like what should we do while users want to do.
for our data args we can have this added inside our code at the last data handler whatever we choose so that our data args usecases remain same.
if you feel can you take this up with this patch? to add the kwarg for eos_token to clean up the interface with users? else we can park this to a next patch.
- For pre tokenised datasets we can ignore the jinja template imo this should be applied only to non tokeniser data sets .
We can add all these things to documentation and I request you to please add documentation with this patch.
There was a problem hiding this comment.
As per offline discussion, addition of kwarg add_eos_token would be done part of this issue and hence documentation of the same would also be taken care.
Though handler documentation is added in this PR.
tuning/data/data_handlers.py
Outdated
| return {dataset_text_field: rendered_text} | ||
|
|
||
|
|
||
| def transform_placeholders(template: str) -> str: |
There was a problem hiding this comment.
@dushyantbehl @ashokponkumar Are we handling nested dataset use case also, as I see every other handler expects dataset element Dict[str, str] and not Dict[str, Dict] ?
There was a problem hiding this comment.
I think we were only handling non nested datasets apart from chat templates...can we test things out with this patch if our code works for nested datasets then can we have a change of the argument type here?
There was a problem hiding this comment.
Also if you can move to utils as we discussed in our last call. Thanks.
There was a problem hiding this comment.
As per offline discussion, handling of nested dataset would be checked and done for all handlers as part of this issue.
Also if you can move to utils as we discussed in our last call. Thanks.
Done
tests/data/test_data_handlers.py
Outdated
| template = "### Input: {{not found}} \n\n ### Response: {{text_label}}" | ||
| formatted_dataset_field = "formatted_data_field" | ||
| tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME) | ||
| with pytest.raises((KeyError, TemplateSyntaxError)): |
There was a problem hiding this comment.
can we catch this error inside our code and give users a simple text error?
There was a problem hiding this comment.
TemplateSyntaxError is not needed anymore as this error comes when there is a space between placeholder variable in the template, and we are handling the space now with transform_placeholders utils function.
For KeyError the text error for user is mentioned in the handler apply_custom_data_formatting_jinja_template.
Signed-off-by: Abhishek <maurya.abhishek@ibm.com>
Signed-off-by: Abhishek <maurya.abhishek@ibm.com>
tests/data/test_data_handlers.py
Outdated
| # https://spdx.dev/learn/handling-license-info/ | ||
|
|
||
| # Third Party | ||
| from jinja2.exceptions import TemplateSyntaxError |
There was a problem hiding this comment.
Can you remove this import as its not used
| Expects to be run as a HF Map API function. | ||
| Args: | ||
| element: the HF Dataset element loaded from a JSON or DatasetDict object. | ||
| dataset_text_field: formatted_dataset_field. |
There was a problem hiding this comment.
Please add tokenizer to the args here.
There was a problem hiding this comment.
Also I know this is not on you but can you please fix the doc string for line 104 as well.
tuning/utils/config_utils.py
Outdated
| return pickle.loads(message_bytes) | ||
|
|
||
|
|
||
| def transform_placeholders(template: str) -> str: |
There was a problem hiding this comment.
could we please rename this function to be more descriptive?
sanitise jinja placeholders?
|
Suggested minor changes and barring those LGTM. |
Signed-off-by: Abhishek <maurya.abhishek@ibm.com>
|
Made the suggested changes. @willmj Feel free to give a final review. |
|
@ashokponkumar please review and merge. Thanks |
|
|
||
| template += tokenizer.eos_token | ||
| template = process_jinja_placeholders(template) | ||
| env = Environment(undefined=StrictUndefined) |
There was a problem hiding this comment.
Can we use SandboxedEnvironment instead? We should avoid Environment as much as possible.
There was a problem hiding this comment.
There was a problem hiding this comment.
Thanks for the suggestion @kmehant. SandboxedEnvironment definitely seems to be good in terms of security, though the time taken for Jinja rendering (or .map() of Handler) is slightly more in this case due to internal checks, but not very significant difference.
@ashokponkumar @dushyantbehl @willmj If it sounds good to go ahead with this, I created a draft PR to include usage of SandboxedEnvironment with additional error handling.
Description of the change
Added a handler
apply_custom_data_formatting_jinja_templatewhich does jinja based template rendering of the dataset.Handling of edge case:
Example template:"### Input: {{Tweet text}} \n\n ### Response: {{text_label}}"Jinja2 by default, does not support placeholders variable names with spaces (e.g., {{Tweet text}}), which will raise an error.
Hence additional preprocessing check (function:
transform_placeholders) has been done. This checks if there is space between the placeholder variable and then process it accordingly (by modifying variable by{{element["Tweet text"]}}.Related issue number
Issue: https://github.ibm.com/ai-foundation/watson-fm-stack-tracker/issues/1470
How to verify the PR
Verify added test cases.
Was the PR tested