@@ -907,6 +907,74 @@ def __init__(
907907 super ().__init__ (data_folder , tokenizer = tokenizer , memory_mode = memory_mode , ** corpusargs )
908908
909909
910+ class AGNEWS (ClassificationCorpus ):
911+ """The AG's News Topic Classification Corpus, classifying news into 4 coarse-grained topics.
912+
913+ Labels: World, Sports, Business, Sci/Tech.
914+ """
915+
916+ def __init__ (
917+ self ,
918+ base_path : Optional [Union [str , Path ]] = None ,
919+ tokenizer : Union [bool , Tokenizer ] = SpaceTokenizer (),
920+ memory_mode = "partial" ,
921+ ** corpusargs ,
922+ ):
923+ """Instantiates AGNews Classification Corpus with 4 classes.
924+
925+ :param base_path: Provide this only if you store the AGNEWS corpus in a specific folder, otherwise use default.
926+ :param tokenizer: Custom tokenizer to use (default is SpaceTokenizer)
927+ :param memory_mode: Set to 'partial' by default. Can also be 'full' or 'none'.
928+ :param corpusargs: Other args for ClassificationCorpus.
929+ """
930+ base_path = flair .cache_root / "datasets" if not base_path else Path (base_path )
931+
932+ dataset_name = self .__class__ .__name__ .lower ()
933+
934+ data_folder = base_path / dataset_name
935+
936+ # download data from same source as in huggingface's implementations
937+ agnews_path = "https://raw.githubusercontent.com/mhjabreel/CharCnn_Keras/master/data/ag_news_csv/"
938+
939+ original_filenames = ["train.csv" , "test.csv" , "classes.txt" ]
940+ new_filenames = ["train.txt" , "test.txt" ]
941+
942+ for original_filename in original_filenames :
943+ cached_path (f"{ agnews_path } { original_filename } " , Path ("datasets" ) / dataset_name / "original" )
944+
945+ data_file = data_folder / new_filenames [0 ]
946+ label_dict = []
947+ label_path = original_filenames [- 1 ]
948+
949+ # read label order
950+ with open (data_folder / "original" / label_path ) as f :
951+ for line in f :
952+ line = line .rstrip ()
953+ label_dict .append (line )
954+
955+ original_filenames = original_filenames [:- 1 ]
956+ if not data_file .is_file ():
957+ for original_filename , new_filename in zip (original_filenames , new_filenames ):
958+ with open (data_folder / "original" / original_filename , encoding = "utf-8" ) as open_fp , open (
959+ data_folder / new_filename , "w" , encoding = "utf-8"
960+ ) as write_fp :
961+ csv_reader = csv .reader (
962+ open_fp , quotechar = '"' , delimiter = "," , quoting = csv .QUOTE_ALL , skipinitialspace = True
963+ )
964+ for id_ , row in enumerate (csv_reader ):
965+ label , title , description = row
966+ # Original labels are [1, 2, 3, 4] -> ['World', 'Sports', 'Business', 'Sci/Tech']
967+ # Re-map to [0, 1, 2, 3].
968+ text = " " .join ((title , description ))
969+
970+ new_label = "__label__"
971+ new_label += label_dict [int (label ) - 1 ]
972+
973+ write_fp .write (f"{ new_label } { text } \n " )
974+
975+ super ().__init__ (data_folder , label_type = "topic" , tokenizer = tokenizer , memory_mode = memory_mode , ** corpusargs )
976+
977+
910978class STACKOVERFLOW (ClassificationCorpus ):
911979 """Stackoverflow corpus classifying questions into one of 20 labels.
912980
0 commit comments