Skip to content

Commit c294784

Browse files
committed
rough IFEval implementation using llm_instruction benchmark
1 parent 30b6464 commit c294784

File tree

12 files changed

+1660
-4
lines changed

12 files changed

+1660
-4
lines changed

flutter/cpp/binary/BUILD

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,7 @@ cc_binary(
5454
"//flutter/cpp/datasets:ade20k",
5555
"//flutter/cpp/datasets:coco",
5656
"//flutter/cpp/datasets:coco_gen",
57+
"//flutter/cpp/datasets:ifeval",
5758
"//flutter/cpp/datasets:imagenet",
5859
"//flutter/cpp/datasets:mmlu_gen",
5960
"//flutter/cpp/datasets:snu_sr",

flutter/cpp/binary/main.cc

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ limitations under the License.
2525
#include "flutter/cpp/datasets/ade20k.h"
2626
#include "flutter/cpp/datasets/coco.h"
2727
#include "flutter/cpp/datasets/coco_gen.h"
28+
#include "flutter/cpp/datasets/ifeval.h"
2829
#include "flutter/cpp/datasets/imagenet.h"
2930
#include "flutter/cpp/datasets/mmlu_gen.h"
3031
#include "flutter/cpp/datasets/snu_sr.h"
@@ -70,6 +71,8 @@ DatasetConfig::DatasetType Str2DatasetType(absl::string_view name) {
7071
return DatasetConfig::COCOGEN;
7172
} else if (absl::EqualsIgnoreCase(name, "MMLU")) {
7273
return DatasetConfig::MMLU;
74+
} else if (absl::EqualsIgnoreCase(name, "IFEVAL")) {
75+
return DatasetConfig::IFEVAL;
7376
} else if (absl::EqualsIgnoreCase(name, "DUMMY")) {
7477
return DatasetConfig::NONE;
7578
} else {
@@ -91,6 +94,8 @@ DatasetConfig::DatasetType BenchmarkId2DatasetType(absl::string_view name) {
9194
return DatasetConfig::SNUSR;
9295
} else if (absl::StartsWith(name, "stable_diffusion")) {
9396
return DatasetConfig::COCOGEN;
97+
} else if (absl::StartsWith(name, "llm_instruction")) {
98+
return DatasetConfig::IFEVAL;
9499
} else if (absl::StartsWith(name, "llm")) {
95100
return DatasetConfig::MMLU;
96101
} else {
@@ -420,6 +425,32 @@ int Main(int argc, char *argv[]) {
420425
flag_list.insert(flag_list.end(), dataset_flags.begin(),
421426
dataset_flags.end());
422427
} break;
428+
case DatasetConfig::IFEVAL: {
429+
bool loose_follow = false;
430+
LOG(INFO) << "IFEval dataset for LLM benchmark";
431+
std::string input_tfrecord, sp_path = "";
432+
std::vector<Flag> dataset_flags{
433+
Flag::CreateFlag(
434+
"input_tfrecord", &input_tfrecord,
435+
"Path to the tfrecord file containing inputs for the model.",
436+
Flag::kRequired),
437+
Flag::CreateFlag("sp_path", &sp_path,
438+
"Path to the sentencepiece model file.",
439+
Flag::kRequired),
440+
Flag::CreateFlag("loose-follow", &loose_follow,
441+
"Whether to loosely check if the instructions are "
442+
"being followed"),
443+
};
444+
445+
if (Flags::Parse(&argc, const_cast<const char **>(argv), dataset_flags) &&
446+
backend) {
447+
dataset.reset(
448+
new IFEval(backend.get(), input_tfrecord, sp_path, loose_follow));
449+
}
450+
// Adds to flag_list for showing help.
451+
flag_list.insert(flag_list.end(), dataset_flags.begin(),
452+
dataset_flags.end());
453+
} break;
423454
case DatasetConfig::NONE:
424455
default:
425456
break;

flutter/cpp/datasets/BUILD

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -237,3 +237,37 @@ cc_library(
237237
"@org_tensorflow//tensorflow/lite/tools/evaluation/proto:evaluation_stages_cc_proto",
238238
],
239239
)
240+
# Library target for the IFEval dataset (ifeval.cc / ifeval.h).
cc_library(
    name = "ifeval",
    srcs = ["ifeval.cc"],
    hdrs = [
        "ifeval.h",
        "utils.h",
    ],
    # On top of the TFLite copts, add AddressSanitizer instrumentation and
    # debug-friendly options when the use_asan condition is selected.
    copts = tflite_copts() + select({
        "//flutter/android/commonlibs:use_asan": [
            "-fsanitize=address",
            "-g",
            "-O1",
            "-fno-omit-frame-pointer",
        ],
        "//conditions:default": [],
    }),
    deps = [
        ":allocator",
        "//flutter/cpp:mlperf_driver",
        "//flutter/cpp:utils",
        "//flutter/cpp/backends:external",
        "//flutter/cpp/datasets/ifeval_utils",
        "//flutter/cpp/datasets/mmlu_utils",
        "//flutter/cpp/datasets/squad_utils",
        "@com_google_absl//absl/container:flat_hash_map",
        "@com_google_protobuf//:protobuf",
        "@com_google_sentencepiece//:sentencepiece_processor",
        "@org_tensorflow//tensorflow/lite/tools/evaluation:utils",
        "@org_tensorflow//tensorflow/lite/tools/evaluation/proto:evaluation_stages_cc_proto",
    ],
)

0 commit comments

Comments
 (0)