Skip to content

Commit bbca39f

Browse files
committed
Address review comments
Signed-off-by: Xiaodong Ye <[email protected]>
1 parent fb835fa commit bbca39f

File tree

1 file changed

+84
-48
lines changed

1 file changed

+84
-48
lines changed

tests/test-backend-ops.cpp

Lines changed: 84 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -449,8 +449,28 @@ enum class test_status_t {
449449
FAIL
450450
};
451451

452+
// Forward declarations for the visitor pattern
453+
struct message_visitor;
454+
455+
// Base class for all message types that can be printed
456+
struct message_data {
457+
virtual ~message_data() {}
458+
virtual void accept(message_visitor& visitor) const = 0;
459+
};
460+
461+
// Message visitor interface
462+
struct message_visitor {
463+
virtual ~message_visitor() {}
464+
virtual void visit(const struct test_operation_info& info) = 0;
465+
virtual void visit(const struct test_summary_info& info) = 0;
466+
virtual void visit(const struct testing_start_info& info) = 0;
467+
virtual void visit(const struct backend_init_info& info) = 0;
468+
virtual void visit(const struct backend_status_info& info) = 0;
469+
virtual void visit(const struct overall_summary_info& info) = 0;
470+
};
471+
452472
// Printer classes for different output formats
453-
struct test_operation_info {
473+
struct test_operation_info : public message_data {
454474
std::string op_name;
455475
std::string op_params;
456476
std::string backend_name;
@@ -482,6 +502,10 @@ struct test_operation_info {
482502
test_status_t status = test_status_t::OK, const std::string& failure_reason = "")
483503
: op_name(op_name), op_params(op_params), backend_name(backend_name), status(status), failure_reason(failure_reason) {}
484504

505+
void accept(message_visitor& visitor) const override {
506+
visitor.visit(*this);
507+
}
508+
485509
// Set error information
486510
void set_error(const std::string& component, const std::string& details) {
487511
has_error = true;
@@ -527,24 +551,32 @@ struct test_operation_info {
527551
}
528552
};
529553

530-
struct test_summary_info {
554+
struct test_summary_info : public message_data {
531555
size_t tests_passed;
532556
size_t tests_total;
533557
bool is_backend_summary = false; // true for backend summary, false for test summary
534558

535559
test_summary_info() = default;
536560
test_summary_info(size_t tests_passed, size_t tests_total, bool is_backend_summary = false)
537561
: tests_passed(tests_passed), tests_total(tests_total), is_backend_summary(is_backend_summary) {}
562+
563+
void accept(message_visitor& visitor) const override {
564+
visitor.visit(*this);
565+
}
538566
};
539567

540-
struct testing_start_info {
568+
struct testing_start_info : public message_data {
541569
size_t device_count;
542570

543571
testing_start_info() = default;
544572
testing_start_info(size_t device_count) : device_count(device_count) {}
573+
574+
void accept(message_visitor& visitor) const override {
575+
visitor.visit(*this);
576+
}
545577
};
546578

547-
struct backend_init_info {
579+
struct backend_init_info : public message_data {
548580
size_t device_index;
549581
size_t total_devices;
550582
std::string device_name;
@@ -562,41 +594,49 @@ struct backend_init_info {
562594
: device_index(device_index), total_devices(total_devices), device_name(device_name), skipped(skipped),
563595
skip_reason(skip_reason), description(description), memory_total_mb(memory_total_mb),
564596
memory_free_mb(memory_free_mb), has_memory_info(has_memory_info) {}
597+
598+
void accept(message_visitor& visitor) const override {
599+
visitor.visit(*this);
600+
}
565601
};
566602

567-
struct backend_status_info {
603+
struct backend_status_info : public message_data {
568604
std::string backend_name;
569605
test_status_t status;
570606

571607
backend_status_info() = default;
572608
backend_status_info(const std::string& backend_name, test_status_t status)
573609
: backend_name(backend_name), status(status) {}
610+
611+
void accept(message_visitor& visitor) const override {
612+
visitor.visit(*this);
613+
}
574614
};
575615

576-
struct overall_summary_info {
616+
struct overall_summary_info : public message_data {
577617
size_t backends_passed;
578618
size_t backends_total;
579619
bool all_passed;
580620

581621
overall_summary_info() = default;
582622
overall_summary_info(size_t backends_passed, size_t backends_total, bool all_passed)
583623
: backends_passed(backends_passed), backends_total(backends_total), all_passed(all_passed) {}
624+
625+
void accept(message_visitor& visitor) const override {
626+
visitor.visit(*this);
627+
}
584628
};
585629

586-
struct printer {
630+
struct printer : public message_visitor {
587631
virtual ~printer() {}
588632
FILE * fout = stdout;
589633
virtual void print_header() {}
590634
virtual void print_test_result(const test_result & result) = 0;
591635
virtual void print_footer() {}
592636

593-
template<typename T>
594-
void print_message(const T& data) {
595-
print_message_impl(&data, typeid(T).name());
637+
void print_message(const message_data& data) {
638+
data.accept(*this);
596639
}
597-
598-
protected:
599-
virtual void print_message_impl(const void* data, const char* type_name) = 0;
600640
};
601641

602642
struct console_printer : public printer {
@@ -608,29 +648,8 @@ struct console_printer : public printer {
608648
}
609649
}
610650

611-
protected:
612-
void print_message_impl(const void* data, const char* type_name) override {
613-
std::string type_str(type_name);
614-
615-
if (type_str.find("test_operation_info") != std::string::npos) {
616-
handle_message(*static_cast<const test_operation_info*>(data));
617-
} else if (type_str.find("test_summary_info") != std::string::npos) {
618-
handle_message(*static_cast<const test_summary_info*>(data));
619-
} else if (type_str.find("testing_start_info") != std::string::npos) {
620-
handle_message(*static_cast<const testing_start_info*>(data));
621-
} else if (type_str.find("backend_init_info") != std::string::npos) {
622-
handle_message(*static_cast<const backend_init_info*>(data));
623-
} else if (type_str.find("backend_status_info") != std::string::npos) {
624-
handle_message(*static_cast<const backend_status_info*>(data));
625-
} else if (type_str.find("overall_summary_info") != std::string::npos) {
626-
handle_message(*static_cast<const overall_summary_info*>(data));
627-
} else {
628-
GGML_ABORT("unknown message type: %s", type_name);
629-
}
630-
}
631-
632-
private:
633-
void handle_message(const test_operation_info& info) {
651+
// Visitor pattern implementations
652+
void visit(const test_operation_info& info) override {
634653
printf(" %s(%s): ", info.op_name.c_str(), info.op_params.c_str());
635654
fflush(stdout);
636655

@@ -685,15 +704,15 @@ struct console_printer : public printer {
685704
}
686705
}
687706

688-
void handle_message(const test_summary_info& info) {
707+
void visit(const test_summary_info& info) override {
689708
if (info.is_backend_summary) {
690709
printf("%zu/%zu backends passed\n", info.tests_passed, info.tests_total);
691710
} else {
692711
printf(" %zu/%zu tests passed\n", info.tests_passed, info.tests_total);
693712
}
694713
}
695714

696-
void handle_message(const backend_status_info& info) {
715+
void visit(const backend_status_info& info) override {
697716
printf(" Backend %s: ", info.backend_name.c_str());
698717
if (info.status == test_status_t::OK) {
699718
printf("\033[1;32mOK\033[0m\n");
@@ -702,11 +721,11 @@ struct console_printer : public printer {
702721
}
703722
}
704723

705-
void handle_message(const testing_start_info& info) {
724+
void visit(const testing_start_info& info) override {
706725
printf("Testing %zu devices\n\n", info.device_count);
707726
}
708727

709-
void handle_message(const backend_init_info& info) {
728+
void visit(const backend_init_info& info) override {
710729
printf("Backend %zu/%zu: %s\n", info.device_index + 1, info.total_devices, info.device_name.c_str());
711730

712731
if (info.skipped) {
@@ -725,7 +744,7 @@ struct console_printer : public printer {
725744
printf("\n");
726745
}
727746

728-
void handle_message(const overall_summary_info& info) {
747+
void visit(const overall_summary_info& info) override {
729748
printf("%zu/%zu backends passed\n", info.backends_passed, info.backends_total);
730749
if (info.all_passed) {
731750
printf("\033[1;32mOK\033[0m\n");
@@ -734,6 +753,7 @@ struct console_printer : public printer {
734753
}
735754
}
736755

756+
private:
737757
void print_test_console(const test_result & result) {
738758
printf(" %s(%s): ", result.op_name.c_str(), result.op_params.c_str());
739759
fflush(stdout);
@@ -838,13 +858,29 @@ struct sql_printer : public printer {
838858
fprintf(fout, ");\n");
839859
}
840860

841-
protected:
842-
// Implementation that checks type and dispatches to appropriate handler
843-
void print_message_impl(const void* data, const char* type_name) override {
844-
// SQL printer doesn't need to handle message types for now
845-
// All necessary output is handled through print_test_result
846-
(void)data;
847-
(void)type_name;
861+
// Visitor pattern implementations - SQL printer doesn't need to handle message types for now
862+
void visit(const test_operation_info& info) override {
863+
(void)info;
864+
}
865+
866+
void visit(const test_summary_info& info) override {
867+
(void)info;
868+
}
869+
870+
void visit(const testing_start_info& info) override {
871+
(void)info;
872+
}
873+
874+
void visit(const backend_init_info& info) override {
875+
(void)info;
876+
}
877+
878+
void visit(const backend_status_info& info) override {
879+
(void)info;
880+
}
881+
882+
void visit(const overall_summary_info& info) override {
883+
(void)info;
848884
}
849885
};
850886

0 commit comments

Comments
 (0)