Commit 27dd211

Prompt Task/Complexity Classifier (#364)
* initial commit
  Signed-off-by: Sarah Yurick <sarahyurick@gmail.com>
* update readmes
  Signed-off-by: Sarah Yurick <sarahyurick@gmail.com>
* edit readmes
  Signed-off-by: Sarah Yurick <sarahyurick@gmail.com>
* working scripts
  Signed-off-by: Sarah Yurick <sarahyurick@gmail.com>
* run isort
  Signed-off-by: Sarah Yurick <sarahyurick@gmail.com>
* modify base
  Signed-off-by: Sarah Yurick <sarahyurick@gmail.com>
* change to output_path
  Signed-off-by: Sarah Yurick <sarahyurick@gmail.com>
* Apply suggestions from code review
  Co-authored-by: Vibhu Jawa <vibhujawa@gmail.com>
  Signed-off-by: Sarah Yurick <53962159+sarahyurick@users.noreply.github.com>
* change to notimplementederror
  Signed-off-by: Sarah Yurick <sarahyurick@gmail.com>

---------

Signed-off-by: Sarah Yurick <sarahyurick@gmail.com>
Signed-off-by: Sarah Yurick <53962159+sarahyurick@users.noreply.github.com>
Co-authored-by: Vibhu Jawa <vibhujawa@gmail.com>
1 parent e820b8b commit 27dd211

12 files changed: +602 -6 lines changed

README.md

Lines changed: 1 addition & 1 deletion
@@ -28,7 +28,7 @@ All of our text pipelines have great multilingual support.
   - [Heuristic Filtering](https://docs.nvidia.com/nemo-framework/user-guide/latest/datacuration/qualityfiltering.html)
   - Classifier Filtering
     - [fastText](https://docs.nvidia.com/nemo-framework/user-guide/latest/datacuration/qualityfiltering.html)
-    - GPU-Accelerated models: [Domain (English and multilingual), Quality, Safety, Educational Content, and Content Type Classification](https://docs.nvidia.com/nemo-framework/user-guide/latest/datacuration/distributeddataclassification.html)
+    - GPU-Accelerated models: [Domain (English and multilingual), Quality, Safety, Educational Content, Content Type, and Prompt Task/Complexity Classification](https://docs.nvidia.com/nemo-framework/user-guide/latest/datacuration/distributeddataclassification.html)
 - **GPU-Accelerated Deduplication**
   - [Exact Deduplication](https://docs.nvidia.com/nemo-framework/user-guide/latest/datacuration/gpudeduplication.html)
   - [Fuzzy Deduplication](https://docs.nvidia.com/nemo-framework/user-guide/latest/datacuration/gpudeduplication.html) via MinHash Locality Sensitive Hashing

docs/user-guide/api/classifiers.rst

Lines changed: 3 additions & 0 deletions
@@ -22,3 +22,6 @@ Classifiers

 .. autoclass:: nemo_curator.classifiers.ContentTypeClassifier
     :members:
+
+.. autoclass:: nemo_curator.classifiers.PromptTaskComplexityClassifier
+    :members:

docs/user-guide/cpuvsgpu.rst

Lines changed: 1 addition & 0 deletions
@@ -72,6 +72,7 @@ The following NeMo Curator modules are GPU based.
 * AEGIS and Instruction-Data-Guard Safety Models
 * FineWeb Educational Content Classification
 * Content Type Classification
+* Prompt Task/Complexity Classification

 GPU modules store the ``DocumentDataset`` using a ``cudf`` backend instead of a ``pandas`` one.
 To read a dataset into GPU memory, one could use the following function call.

docs/user-guide/distributeddataclassification.rst

Lines changed: 24 additions & 1 deletion
@@ -15,7 +15,7 @@ NeMo Curator provides a module to help users run inference with pre-trained mode
 This is achieved by chunking the datasets across multiple computing nodes, each equipped with multiple GPUs, to accelerate the classification task in a distributed manner.
 Since the classification of a single text document is independent of other documents within the dataset, we can distribute the workload across multiple nodes and GPUs to perform parallel processing.

-Domain (English and multilingual), quality, content safety, educational content, and content type models are tasks we include as examples within our module.
+Domain (English and multilingual), quality, content safety, educational content, content type, and prompt task/complexity models are tasks we include as examples within our module.

 Here, we summarize why each is useful for training an LLM:

@@ -33,6 +33,8 @@ Here, we summarize why each is useful for training an LLM:

 - The **Content Type Classifier** is designed to categorize documents into one of 11 distinct speech types based on their content. It analyzes and understands the nuances of textual information, enabling accurate classification across a diverse range of content types.

+- The **Prompt Task/Complexity Classifier** is a multi-headed model which classifies English text prompts across task types and complexity dimensions.
+
 -----------------------------------------
 Usage
 -----------------------------------------
@@ -256,6 +258,27 @@ Let's see how ``ContentTypeClassifier`` works in a small excerpt taken from ``ex
 In this example, the content type classifier is obtained directly from `Hugging Face <https://huggingface.co/nvidia/content-type-classifier-deberta>`_.
 It filters the input dataset to include only documents classified as "Blogs" or "News".

+Prompt Task/Complexity Classifier
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+The Prompt Task/Complexity Classifier is a multi-headed model which classifies English text prompts across task types and complexity dimensions. Tasks are classified across 11 common categories. Complexity is evaluated across 6 dimensions and ensembled to create an overall complexity score.
+
+Here's an example of how to use the ``PromptTaskComplexityClassifier``:
+
+.. code-block:: python
+
+    from nemo_curator.classifiers import PromptTaskComplexityClassifier
+
+    files = get_all_files_paths_under("my_dataset/")
+    input_dataset = DocumentDataset.read_json(files, backend="cudf")
+
+    classifier = PromptTaskComplexityClassifier()
+    result_dataset = classifier(dataset=input_dataset)
+
+    result_dataset.to_json("labeled_dataset/")
+
+The prompt task and complexity classifier is obtained from `Hugging Face <https://huggingface.co/nvidia/prompt-task-and-complexity-classifier>`_.
+
 -----------------------------------------
 CrossFit Integration
 -----------------------------------------
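
Reviewer note: the added documentation says complexity is "evaluated across 6 dimensions and ensembled" into one score, but the diff does not show how. The sketch below is one plausible reading, a weighted average over six per-dimension scores. The dimension names, the weights, and the helper `ensemble_complexity` are all assumptions made for illustration; none of them are taken from this commit, and the model card should be checked for the real definition.

```python
from typing import Dict, Optional

# ASSUMPTION: dimension names and weights are illustrative placeholders,
# not values confirmed by this commit.
ASSUMED_WEIGHTS: Dict[str, float] = {
    "creativity_scope": 0.35,
    "reasoning": 0.25,
    "constraint_ct": 0.15,
    "domain_knowledge": 0.15,
    "contextual_knowledge": 0.05,
    "number_of_few_shots": 0.05,
}


def ensemble_complexity(
    scores: Dict[str, float],
    weights: Optional[Dict[str, float]] = None,
) -> float:
    """Combine per-dimension scores into one weighted-average score."""
    weights = ASSUMED_WEIGHTS if weights is None else weights
    missing = set(weights) - set(scores)
    if missing:
        raise KeyError(f"missing complexity dimensions: {sorted(missing)}")
    # Dividing by the weight total keeps the result a true average
    # even if the weights do not sum exactly to 1.
    total_weight = sum(weights.values())
    return sum(weights[name] * scores[name] for name in weights) / total_weight
```

With these placeholder weights, a prompt that scores high only on the first dimension contributes about a third of the maximum, which shows how a weighted ensemble lets some dimensions dominate the overall score.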

examples/classifiers/README.md

Lines changed: 1 addition & 0 deletions
@@ -9,6 +9,7 @@ The Python scripts in this directory demonstrate how to run classification on yo
 - Instruction-Data-Guard Model
 - FineWeb Educational Content Classifier
 - Content Type Classifier
+- Prompt Task/Complexity Classifier

 For more information about these classifiers, please see NeMo Curator's [Distributed Data Classification documentation](https://docs.nvidia.com/nemo-framework/user-guide/latest/datacuration/distributeddataclassification.html).

Lines changed: 65 additions & 0 deletions
@@ -0,0 +1,65 @@
+# Copyright (c) 2024, NVIDIA CORPORATION.  All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import argparse
+import time
+
+from nemo_curator.classifiers import PromptTaskComplexityClassifier
+from nemo_curator.datasets import DocumentDataset
+from nemo_curator.utils.distributed_utils import get_client
+from nemo_curator.utils.script_utils import ArgumentHelper
+
+
+def main(args):
+    global_st = time.time()
+
+    # Input can be a string or list
+    input_file_path = "/path/to/data"
+    output_file_path = "./"
+
+    client_args = ArgumentHelper.parse_client_args(args)
+    client_args["cluster_type"] = "gpu"
+    client = get_client(**client_args)
+
+    input_dataset = DocumentDataset.read_json(
+        input_file_path, backend="cudf", add_filename=True
+    )
+
+    prompt_task_complexity_classifier = PromptTaskComplexityClassifier()
+    result_dataset = prompt_task_complexity_classifier(dataset=input_dataset)
+
+    result_dataset.to_json(output_path=output_file_path, write_to_filename=True)
+
+    global_et = time.time()
+    print(
+        f"Total time taken for prompt task and complexity classifier inference: {global_et-global_st} s",
+        flush=True,
+    )
+
+    client.close()
+
+
+def attach_args(
+    parser=argparse.ArgumentParser(
+        formatter_class=argparse.ArgumentDefaultsHelpFormatter
+    ),
+):
+    argumentHelper = ArgumentHelper(parser)
+    argumentHelper.add_distributed_classifier_cluster_args()
+
+    return argumentHelper.parser
+
+
+if __name__ == "__main__":
+    main(attach_args().parse_args())
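
Reviewer note: the example script writes the labeled dataset back out as JSON but stops there. As a rough sketch of a follow-up step, the snippet below reads JSON Lines records and keeps only prompts of chosen task types above a complexity threshold. It uses only the standard library so it runs without a GPU. The field names `task_type_1` and `prompt_complexity_score`, and the helpers `parse_jsonl` and `select_prompts`, are assumptions for illustration; the diff shown here does not confirm the classifier's output column names.

```python
import json
from typing import Any, Dict, Iterable, Iterator, Set


def parse_jsonl(lines: Iterable[str]) -> Iterator[Dict[str, Any]]:
    """Yield one record per non-empty JSON Lines entry."""
    for line in lines:
        line = line.strip()
        if line:
            yield json.loads(line)


def select_prompts(
    records: Iterable[Dict[str, Any]],
    allowed_tasks: Set[str],
    min_complexity: float,
    task_field: str = "task_type_1",  # ASSUMED column name
    score_field: str = "prompt_complexity_score",  # ASSUMED column name
) -> Iterator[Dict[str, Any]]:
    """Keep records whose task is allowed and whose score meets the threshold."""
    for record in records:
        task = record.get(task_field)
        score = record.get(score_field)
        # Records lacking either field are skipped rather than guessed at.
        if task is None or score is None:
            continue
        if task in allowed_tasks and float(score) >= min_complexity:
            yield record
```

A caller could pass an open file handle straight to `parse_jsonl`, since iterating a text file yields its lines.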

nemo_curator/classifiers/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -19,6 +19,7 @@
 from .content_type import ContentTypeClassifier
 from .domain import DomainClassifier, MultilingualDomainClassifier
 from .fineweb_edu import FineWebEduClassifier
+from .prompt_task_complexity import PromptTaskComplexityClassifier
 from .quality import QualityClassifier

 __all__ = [
@@ -29,4 +30,5 @@
     "InstructionDataGuardClassifier",
     "FineWebEduClassifier",
     "ContentTypeClassifier",
+    "PromptTaskComplexityClassifier",
 ]

nemo_curator/classifiers/base.py

Lines changed: 3 additions & 3 deletions
@@ -16,7 +16,7 @@

 os.environ["RAPIDS_NO_INITIALIZE"] = "1"
 from abc import ABC, abstractmethod
-from typing import List, Optional
+from typing import List, Optional, Union

 import torch
 import torch.nn as nn
@@ -37,8 +37,8 @@ def __init__(
         labels: Optional[List[str]],
         filter_by: Optional[List[str]],
         batch_size: int,
-        out_dim: int,
-        pred_column: str,
+        out_dim: Optional[int],
+        pred_column: Union[str, List[str]],
         max_chars: int,
         device_type: str,
         autocast: bool,
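
Reviewer note: widening `pred_column` to `Union[str, List[str]]` suggests that a multi-headed classifier may write several prediction columns where earlier classifiers wrote one. Code that consumes the parameter then has to handle both shapes. A minimal sketch of one way to do that is to normalize to a list up front. The helper name `as_column_list` is hypothetical and is not part of `nemo_curator`.

```python
from typing import List, Union


def as_column_list(pred_column: Union[str, List[str]]) -> List[str]:
    """Return prediction column names as a list, whatever shape was passed."""
    # A bare string is a single column; check it first because a str is
    # itself iterable and would otherwise be split into characters.
    if isinstance(pred_column, str):
        return [pred_column]
    # Copy so callers cannot mutate the list held by the classifier.
    return list(pred_column)
```

Normalizing once at the boundary means downstream code can always loop over columns without repeating the `isinstance` check.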
