Skip to content

Commit 58d9e8f

Browse files
authored
Update main_double-dispatch.R
1 parent d700e7f commit 58d9e8f

File tree

1 file changed

+5
-30
lines changed

1 file changed

+5
-30
lines changed

scripts/main_double-dispatch.R

Lines changed: 5 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -10,24 +10,14 @@ library(dplyr)
1010
# 1. Define the DecisionTree and MarkovModel classes
1111
###############################
1212

13-
# DecisionTree constructors and methods
1413
DecisionTree <- function(data, ...) {
15-
UseMethod("DecisionTree")
16-
}
17-
18-
DecisionTree.default <- function(data, ...) {
1914
call_obj <- match.call()
2015
structure(list(data = data),
2116
call = call_obj,
2217
class = c("DecisionTree", "Model"))
2318
}
2419

25-
# MarkovModel constructors and methods
26-
MarkovModel <- function(model, ...) {
27-
UseMethod("MarkovModel")
28-
}
29-
30-
MarkovModel.default <- function(model = NA,
20+
MarkovModel <- function(model = NA,
3121
init_probs = NA,
3222
trans_matrix = NA,
3323
cost_matrix = NA,
@@ -58,21 +48,6 @@ CombinedModel <- function(...) {
5848
structure(models, class = "CombinedModel")
5949
}
6050

61-
# Infix operator for chaining (optional)
62-
`%->%` <- function(mod1, mod2) {
63-
UseMethod("%->%")
64-
}
65-
`%->%.default` <- function(mod1, mod2) {
66-
if (!inherits(mod2, "Model")) {
67-
stop("All arguments must be of class 'Model'")
68-
} else {
69-
stop("No method for this model")
70-
}
71-
}
72-
`%->%.Model` <- function(mod1, mod2) {
73-
CombinedModel(mod1, mod2)
74-
}
75-
7651
###############################
7752
# 3. Define the runner functions for each model
7853
###############################
@@ -112,8 +87,8 @@ run_model.MarkovModel <- function(model) {
11287
class = c("MarkovModelOutput", "output", class(model)))
11388
}
11489

115-
# Run a CombinedModel: update each model in the chain (using update_model())
116-
# and store each model's output in a list.
90+
# update each model in the chain (using update_model())
91+
# and store each model's output in a list
11792
run_model.CombinedModel <- function(model_chain) {
11893
results <- list()
11994

@@ -122,14 +97,14 @@ run_model.CombinedModel <- function(model_chain) {
12297

12398
if (i > 1) {
12499
# Update current_model based on the output of the previous model
125-
# The update_model() S3 function inspects the previous result’s class
100+
# S3 function inspects the previous result’s class
126101
current_model <- update_model(current_model, results[[i - 1]])
127102
}
128103

129104
results[[i]] <- run_model(current_model)
130105
}
131106

132-
structure(results, class = c("CombinedModelOutput", "output"))
107+
structure(results, class = c("CombinedModelOutput", class(results)))
133108
}
134109

135110
###############################

0 commit comments

Comments
 (0)