-import io
import logging
import os
import pickle
-import tarfile
from collections import namedtuple

import numpy as np
import pytest
-import requests
from absl.testing import parameterized
-from datasets import load_dataset

from keras.src import backend
from keras.src import layers
...
logging.basicConfig(level=logging.INFO)


-def get_dataset_text(dataset_identifier: str, nsamples=1000) -> str:
-    """
-    Loads a specified dataset and extracts its text content into a
-    single string.
-    """
-    DATASET_CONFIGS = {
-        "wikitext2": {
-            "name": "wikitext",
-            "config": "wikitext-2-raw-v1",
-            "split": "test",
-            "text_column": "text",
-        },
-        "ptb": {
-            "name": "ptb_text_only",
-            "config": "penn_treebank",
-            "split": "validation",
-            "text_column": "sentence",
-        },
-        "c4": {
-            "name": "allenai/c4",
-            "config": "en",
-            "split": "validation",  # Use validation for C4's test split
-            "text_column": "text",
-        },
-    }
-
-    if dataset_identifier not in DATASET_CONFIGS:
-        raise ValueError(
-            f"Unknown dataset identifier '{dataset_identifier}'. "
-            f"Available options are: {list(DATASET_CONFIGS.keys())}"
-        )
-
-    config = DATASET_CONFIGS[dataset_identifier]
-
-    if dataset_identifier == "ptb":
-        url = "http://www.fit.vutbr.cz/~imikolov/rnnlm/simple-examples.tgz"
-        try:
-            # 1. Download the archive into memory
-            response = requests.get(url)
-            response.raise_for_status()
-
-            # 2. Extract only the test file from the in-memory archive
-            with tarfile.open(
-                fileobj=io.BytesIO(response.content), mode="r:gz"
-            ) as tar:
-                test_path = "./simple-examples/data/ptb.test.txt"
-                test_bytes = tar.extractfile(test_path).read()
-
-            # 3. Decode the bytes and join into a single string
-            test_lines = test_bytes.decode("utf-8").strip().split("\n")
-            all_text = "\n\n".join(test_lines)
-
-            print("✅ Successfully processed PTB test data.")
-            return all_text
-
-        except Exception as e:
-            print(f"Failed to download or process PTB data: {e!r}")
-            raise e
-
-    load_kwargs = {"name": config["config"]}
-
-    if dataset_identifier == "c4":
-        load_kwargs["streaming"] = True
-    # For PTB, force a redownload to bypass potential cache errors.
-    if dataset_identifier == "ptb":
-        load_kwargs["download_mode"] = "force_redownload"
-
-    print(f"Loading dataset '{config['name']}'...")
-
-    test_data = load_dataset(
-        config["name"], split=config["split"], **load_kwargs
-    )
-
-    if dataset_identifier == "c4":
-        print(f" -> Limiting C4 to the first {nsamples} documents for speed.")
-        test_data = test_data.take(nsamples)
-
-    all_text = "\n\n".join(
-        row[config["text_column"]]
-        for row in test_data
-        if row.get(config["text_column"])
-    )
-
-    print(f"Successfully loaded and processed {dataset_identifier}.")
-    return all_text
-
-
def _get_model():
    input_a = Input(shape=(3,), batch_size=2, name="input_a")
    input_b = Input(shape=(3,), batch_size=2, name="input_b")