44from typing import Callable , Iterable , Optional , TextIO
55
66import pydantic
7+ from datasets import load_dataset
78
89from .schemas import ExportMessageNode , ExportMessageTree
910
@@ -17,22 +18,24 @@ def open_jsonl_read(input_file_path: str | Path) -> TextIO:
1718 return input_file_path .open ("r" , encoding = "UTF-8" )
1819
1920
def read_oasst_obj(obj_dict: dict) -> ExportMessageTree | ExportMessageNode:
    """Validate an already-decoded object and return the matching export model.

    A dict containing ``"message_id"`` is parsed as an ``ExportMessageNode``;
    otherwise a dict containing ``"message_tree_id"`` is parsed as an
    ``ExportMessageTree``.

    Raises:
        RuntimeError: if the dict contains neither key.
    """
    # "message_id" is checked first, so it wins when both keys are present.
    if "message_id" in obj_dict:
        target_model = ExportMessageNode
    elif "message_tree_id" in obj_dict:
        target_model = ExportMessageTree
    else:
        raise RuntimeError("Unknown object in jsonl file")

    # validate data
    return pydantic.parse_obj_as(target_model, obj_dict)
2929
3030
def read_oasst_jsonl(
    input_file_path: str | Path,
) -> Iterable[ExportMessageTree | ExportMessageNode]:
    """Lazily yield one validated export object per line of a JSONL file.

    Each line is decoded with ``json.loads`` and the resulting dict is passed
    to ``read_oasst_obj`` for validation.
    """
    with open_jsonl_read(input_file_path) as jsonl_stream:
        # read one object per line
        yield from (read_oasst_obj(json.loads(raw_line)) for raw_line in jsonl_stream)
3639
3740
3841def read_message_trees (input_file_path : str | Path ) -> Iterable [ExportMessageTree ]:
@@ -42,18 +45,85 @@ def read_message_trees(input_file_path: str | Path) -> Iterable[ExportMessageTre
4245
4346
def read_message_tree_list(
    input_file_path: str | Path,
    filter: Optional[Callable[[ExportMessageTree], bool]] = None,
) -> list[ExportMessageTree]:
    """Read all message trees from a file into a list.

    When ``filter`` is given, only trees for which it returns a truthy value
    are kept; without it every tree is returned.
    """
    trees = read_message_trees(input_file_path)
    if not filter:
        return list(trees)
    return [tree for tree in trees if filter(tree)]
4852
4953
def convert_hf_message(row: dict) -> None:
    """Rewrite the columnar ``emojis`` and ``labels`` fields of ``row`` in place.

    ``row["emojis"]`` is turned from parallel ``"name"``/``"count"`` sequences
    into a ``{name: count}`` mapping. ``row["labels"]`` is turned from parallel
    ``"name"``/``"value"``/``"count"`` sequences into
    ``{name: {"value": value, "count": count}}``. Fields that are missing or
    falsy are left untouched.
    """
    emoji_columns = row.get("emojis")
    if emoji_columns:
        names = emoji_columns["name"]
        counts = emoji_columns["count"]
        row["emojis"] = {name: count for name, count in zip(names, counts)}

    label_columns = row.get("labels")
    if label_columns:
        label_triples = zip(
            label_columns["name"],
            label_columns["value"],
            label_columns["count"],
        )
        converted = {}
        for name, value, count in label_triples:
            converted[name] = {"value": value, "count": count}
        row["labels"] = converted
64+
65+
def read_messages(input_file_path: str | Path) -> Iterable[ExportMessageNode]:
    """Lazily yield every object of a JSONL file as an ``ExportMessageNode``.

    Asserts that each object produced by ``read_oasst_jsonl`` is a message
    node, not a message tree.
    """
    for parsed_obj in read_oasst_jsonl(input_file_path):
        assert isinstance(parsed_obj, ExportMessageNode)
        yield parsed_obj
5470
5571
def read_message_list(
    input_file_path: str | Path,
    filter: Optional[Callable[[ExportMessageNode], bool]] = None,
) -> list[ExportMessageNode]:
    """Read all messages from a file into a list.

    When ``filter`` is given, only messages for which it returns a truthy
    value are kept; without it every message is returned.
    """
    messages = read_messages(input_file_path)
    if not filter:
        return list(messages)
    return [message for message in messages if filter(message)]
77+
78+
def read_dataset_message_trees(
    hf_dataset_name: str = "OpenAssistant/oasst1",
    split: str = "train+validation",
) -> Iterable[ExportMessageTree]:
    """Rebuild message trees from a flat dataset of message rows.

    The dataset is loaded with ``load_dataset`` and iterated once. A row whose
    ``parent_id`` is ``None`` starts a new tree; every other row is appended to
    the ``replies`` of the nearest message on the current root-to-leaf path
    whose ``message_id`` equals the row's ``parent_id``. This only works when
    a parent row always precedes its replies and the replies of one branch are
    contiguous (depth-first order).

    Args:
        hf_dataset_name: Dataset identifier passed to ``load_dataset``.
        split: Split expression passed to ``load_dataset``.

    Yields:
        One validated ``ExportMessageTree`` per root row, emitted when the next
        root row (or the end of the dataset) is reached.

    Raises:
        RuntimeError: if a reply row appears before any root row, or its
            parent is not on the current root-to-leaf path.
    """
    dataset = load_dataset(hf_dataset_name, split=split)

    tree_dict: Optional[dict] = None
    # Stack of rows on the path from the current root to the last seen row.
    parents: list[dict] = []
    for row in dataset:
        convert_hf_message(row)
        if row["parent_id"] is None:
            # A new root closes the tree collected so far.
            if tree_dict:
                tree = read_oasst_obj(tree_dict)
                assert isinstance(tree, ExportMessageTree)
                yield tree

            tree_dict = {
                "message_tree_id": row["message_id"],
                "tree_state": row["tree_state"],
                "prompt": row,
            }
            parents = []
        else:
            # Unwind the path until the row's parent is on top of the stack.
            while parents and parents[-1]["message_id"] != row["parent_id"]:
                parents.pop()
            if not parents:
                # Previously this surfaced as an opaque IndexError/TypeError.
                message_id = row.get("message_id")
                parent_id = row["parent_id"]
                raise RuntimeError(
                    f"Message {message_id!r} references parent {parent_id!r} which is "
                    "not on the current tree path; rows must be in depth-first order"
                )
            parents[-1].setdefault("replies", []).append(row)

        # These keys are carried by the tree dict, not by the nested messages.
        row.pop("message_tree_id", None)
        row.pop("tree_state", None)
        parents.append(row)

    # Emit the last tree, which has no following root row to trigger it.
    if tree_dict:
        tree = read_oasst_obj(tree_dict)
        assert isinstance(tree, ExportMessageTree)
        yield tree
117+
118+
def read_dataset_messages(
    hf_dataset_name: str = "OpenAssistant/oasst1",
    split: str = "train+validation",
) -> Iterable[ExportMessageNode]:
    """Lazily yield every row of a dataset as a validated ``ExportMessageNode``.

    The dataset is loaded with ``load_dataset``; each row is normalised in
    place by ``convert_hf_message`` and then validated by ``read_oasst_obj``.
    """
    for row in load_dataset(hf_dataset_name, split=split):
        convert_hf_message(row)
        node = read_oasst_obj(row)
        assert isinstance(node, ExportMessageNode)
        yield node
0 commit comments