Skip to content

Commit db1d479

Browse files
authored
[fit] optimize test framework to support executing specific actions at test startup (#205)
1 parent 62d49ba commit db1d479

File tree

9 files changed

+184
-58
lines changed

9 files changed

+184
-58
lines changed

framework/fit/java/fit-test/fit-test-framework/src/main/java/modelengine/fitframework/test/annotation/Sql.java

Lines changed: 26 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,15 +15,38 @@
1515
* 用于在测试用例前执行初始化 Sql 语句。
1616
*
1717
* @author 易文渊
18+
* @author 季聿阶
1819
* @since 2024-07-21
1920
*/
2021
@Target({ElementType.TYPE, ElementType.METHOD})
2122
@Retention(RetentionPolicy.RUNTIME)
2223
public @interface Sql {
2324
/**
24-
* 获取 sql 脚本文件路径。
25+
* 获取前置 SQL 脚本文件路径。
2526
*
26-
* @return 表示 sql 脚本文件路径集合的 {@code String[]}。
27+
* @return 表示前置 SQL 脚本文件路径集合的 {@link String}{@code []}。
2728
*/
28-
String[] scripts();
29+
String[] before() default {};
30+
31+
/**
32+
* 获取后置 SQL 脚本文件路径。
33+
*
34+
* @return 表示后置 SQL 脚本文件路径集合的 {@link String}{@code []}。
35+
*/
36+
String[] after() default {};
37+
38+
/**
39+
* 获取 SQL 脚本执行位置。
40+
*/
41+
enum Position {
42+
/**
43+
* 在前置执行。
44+
*/
45+
BEFORE,
46+
47+
/**
48+
* 在后置执行。
49+
*/
50+
AFTER
51+
}
2952
}

framework/fit/java/fit-test/fit-test-framework/src/main/java/modelengine/fitframework/test/domain/TestPlugin.java

Lines changed: 20 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,9 @@
2727
import java.lang.reflect.Field;
2828
import java.util.Arrays;
2929
import java.util.List;
30+
import java.util.Map;
3031
import java.util.Set;
32+
import java.util.function.Supplier;
3133
import java.util.stream.Collectors;
3234
import java.util.stream.Stream;
3335

@@ -64,7 +66,7 @@ public TestPlugin(FitRuntime runtime, TestContextConfiguration configuration) {
6466
Validation.notNull(configuration, "The configuration to create test plugin cannot be null.");
6567
this.packageScanner = this.scanner((packageScanner, clazz) -> this.onClassDetected(packageScanner, clazz,
6668
// 包含的类已经提前注册,因此需要将包含的和排除的类进行合并。
67-
Stream.concat(Arrays.stream(this.configuration.includeClasses()),
69+
Stream.concat(this.configuration.includeClasses().keySet().stream(),
6870
Arrays.stream(this.configuration.excludeClasses())).collect(Collectors.toSet())));
6971
}
7072

@@ -93,6 +95,7 @@ protected void registerSystemBeans() {
9395
@Override
9496
protected void scanBeans() {
9597
this.registerBeans(this.configuration.includeClasses());
98+
this.configuration.actions().forEach(action -> action.accept(this));
9699
this.scan(this.configuration.scannedPackages());
97100
this.registerMockedBeans(this.configuration.mockedBeanFields());
98101
}
@@ -111,10 +114,22 @@ private void onClassDetected(PackageScanner scanner, Class<?> clazz, Set<Class<?
111114
}
112115
}
113116

114-
private void registerBeans(Class<?>[] classArray) {
115-
Arrays.stream(classArray)
116-
.filter(clazz -> !this.container().lookup(clazz).isPresent())
117-
.forEach(clazz -> this.container().registry().register(clazz));
117+
private void registerBeans(Map<Class<?>, Supplier<Object>> classes) {
118+
classes.entrySet()
119+
.stream()
120+
.filter(entry -> this.container().lookup(entry.getKey()).isEmpty())
121+
.forEach(entry -> {
122+
if (entry.getValue() == null) {
123+
this.container().registry().register(entry.getKey());
124+
} else {
125+
Object bean = entry.getValue().get();
126+
if (bean == null) {
127+
this.container().registry().register(entry.getKey());
128+
} else {
129+
this.container().registry().register(bean);
130+
}
131+
}
132+
});
118133
}
119134

120135
private void scan(Set<String> basePackages) {

framework/fit/java/fit-test/fit-test-framework/src/main/java/modelengine/fitframework/test/domain/listener/DataSourceListener.java

Lines changed: 15 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -6,39 +6,39 @@
66

77
package modelengine.fitframework.test.domain.listener;
88

9-
import modelengine.fitframework.ioc.BeanContainer;
10-
import modelengine.fitframework.ioc.BeanNotFoundException;
119
import modelengine.fitframework.test.annotation.EnableDataSource;
12-
import modelengine.fitframework.test.domain.TestContext;
10+
import modelengine.fitframework.test.domain.resolver.TestContextConfiguration;
1311
import modelengine.fitframework.test.domain.util.AnnotationUtils;
12+
import modelengine.fitframework.util.MapBuilder;
1413

1514
import org.h2.jdbcx.JdbcConnectionPool;
1615

1716
import java.util.Optional;
17+
import java.util.function.Supplier;
1818

1919
import javax.sql.DataSource;
2020

2121
/**
2222
* 用于注入 dataSource 的监听器。
2323
*
2424
* @author 易文渊
25+
* @author 季聿阶
2526
* @since 2024-07-21
2627
*/
2728
public class DataSourceListener implements TestListener {
2829
@Override
29-
public void beforeTestClass(TestContext context) {
30-
Class<?> clazz = context.testClass();
30+
public Optional<TestContextConfiguration> config(Class<?> clazz) {
3131
Optional<EnableDataSource> annotationOption = AnnotationUtils.getAnnotation(clazz, EnableDataSource.class);
32-
if (!annotationOption.isPresent()) {
33-
return;
34-
}
35-
BeanContainer beanContainer = context.plugin().container();
36-
try {
37-
beanContainer.beans().get(DataSource.class);
38-
} catch (BeanNotFoundException e) {
39-
EnableDataSource enableDataSource = annotationOption.get();
40-
DataSource dataSource = JdbcConnectionPool.create(enableDataSource.model().getUrl(), "sa", "sa");
41-
beanContainer.registry().register(dataSource);
32+
if (annotationOption.isEmpty()) {
33+
return Optional.empty();
4234
}
35+
TestContextConfiguration customConfig = TestContextConfiguration.custom()
36+
.testClass(clazz)
37+
.includeClasses(MapBuilder.<Class<?>, Supplier<Object>>get().put(DataSource.class, () -> {
38+
EnableDataSource enableDataSource = annotationOption.get();
39+
return JdbcConnectionPool.create(enableDataSource.model().getUrl(), "sa", "sa");
40+
}).build())
41+
.build();
42+
return Optional.of(customConfig);
4343
}
4444
}

framework/fit/java/fit-test/fit-test-framework/src/main/java/modelengine/fitframework/test/domain/listener/MockMvcListener.java

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
import modelengine.fitframework.test.domain.mvc.request.MockRequestBuilder;
1818
import modelengine.fitframework.test.domain.resolver.TestContextConfiguration;
1919
import modelengine.fitframework.test.domain.util.AnnotationUtils;
20+
import modelengine.fitframework.util.MapBuilder;
2021
import modelengine.fitframework.util.StringUtils;
2122
import modelengine.fitframework.util.ThreadUtils;
2223

@@ -26,6 +27,7 @@
2627
import java.util.Objects;
2728
import java.util.Optional;
2829
import java.util.Set;
30+
import java.util.function.Supplier;
2931

3032
/**
3133
* 用于注入 mockMvc 的监听器。
@@ -50,12 +52,12 @@ public MockMvcListener(int port) {
5052

5153
@Override
5254
public Optional<TestContextConfiguration> config(Class<?> clazz) {
53-
if (!AnnotationUtils.getAnnotation(clazz, EnableMockMvc.class).isPresent()) {
55+
if (AnnotationUtils.getAnnotation(clazz, EnableMockMvc.class).isEmpty()) {
5456
return Optional.empty();
5557
}
5658
TestContextConfiguration configuration = TestContextConfiguration.custom()
5759
.testClass(clazz)
58-
.includeClasses(new Class[] {MockController.class})
60+
.includeClasses(MapBuilder.<Class<?>, Supplier<Object>>get().put(MockController.class, null).build())
5961
.scannedPackages(DEFAULT_SCAN_PACKAGES)
6062
.build();
6163
return Optional.of(configuration);
@@ -64,7 +66,7 @@ public Optional<TestContextConfiguration> config(Class<?> clazz) {
6466
@Override
6567
public void beforeTestClass(TestContext context) {
6668
Class<?> testClass = context.testClass();
67-
if (!AnnotationUtils.getAnnotation(testClass, EnableMockMvc.class).isPresent()) {
69+
if (AnnotationUtils.getAnnotation(testClass, EnableMockMvc.class).isEmpty()) {
6870
return;
6971
}
7072
MockMvc mockMvc = new MockMvc(this.port);

framework/fit/java/fit-test/fit-test-framework/src/main/java/modelengine/fitframework/test/domain/listener/SqlExecuteListener.java

Lines changed: 48 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -6,21 +6,26 @@
66

77
package modelengine.fitframework.test.domain.listener;
88

9+
import modelengine.fitframework.plugin.Plugin;
910
import modelengine.fitframework.test.annotation.Sql;
1011
import modelengine.fitframework.test.domain.TestContext;
12+
import modelengine.fitframework.test.domain.resolver.TestContextConfiguration;
1113
import modelengine.fitframework.util.IoUtils;
1214

1315
import java.io.IOException;
1416
import java.lang.reflect.Method;
1517
import java.sql.Connection;
1618
import java.sql.SQLException;
19+
import java.util.List;
20+
import java.util.Optional;
1721

1822
import javax.sql.DataSource;
1923

2024
/**
2125
* 用于执行 SQL 脚本。
2226
*
2327
* @author 易文渊
28+
* @author 季聿阶
2429
* @since 2024-07-21
2530
*/
2631
public class SqlExecuteListener implements TestListener {
@@ -29,34 +34,66 @@ public class SqlExecuteListener implements TestListener {
2934
private Sql globalSql;
3035

3136
@Override
32-
public void beforeTestClass(TestContext context) {
33-
Class<?> testClass = context.testClass();
34-
this.globalSql = testClass.getAnnotation(Sql.class);
37+
public Optional<TestContextConfiguration> config(Class<?> clazz) {
38+
this.globalSql = clazz.getAnnotation(Sql.class);
39+
if (this.globalSql == null) {
40+
return Optional.empty();
41+
}
42+
TestContextConfiguration configuration =
43+
TestContextConfiguration.custom().testClass(clazz).actions(List.of(this::executeAction)).build();
44+
return Optional.of(configuration);
45+
}
46+
47+
private void executeAction(Plugin plugin) {
48+
if (this.globalSql == null) {
49+
return;
50+
}
51+
executeSql(plugin, this.globalSql, Sql.Position.BEFORE);
3552
}
3653

3754
@Override
3855
public void beforeTestMethod(TestContext context) {
39-
execSql(globalSql, context);
40-
execMethodSql(context);
56+
execMethodSql(context, Sql.Position.BEFORE);
57+
}
58+
59+
@Override
60+
public void afterTestMethod(TestContext context) {
61+
execMethodSql(context, Sql.Position.AFTER);
4162
}
4263

43-
private static void execMethodSql(TestContext context) {
64+
private static void execMethodSql(TestContext context, Sql.Position position) {
4465
Method method = context.testMethod();
4566
Sql sql = method.getAnnotation(Sql.class);
46-
execSql(sql, context);
67+
executeSql(context.plugin(), sql, position);
68+
}
69+
70+
@Override
71+
public void afterTestClass(TestContext context) {
72+
Class<?> testClass = context.testClass();
73+
Sql sql = testClass.getAnnotation(Sql.class);
74+
executeSql(context.plugin(), sql, Sql.Position.AFTER);
4775
}
4876

49-
private static void execSql(Sql sql, TestContext context) {
77+
private static void executeSql(Plugin plugin, Sql sql, Sql.Position position) {
5078
if (sql == null) {
5179
return;
5280
}
53-
DataSource dataSource = context.plugin().container().beans().get(DataSource.class);
81+
DataSource dataSource = plugin.container().beans().get(DataSource.class);
5482
try (Connection connection = dataSource.getConnection()) {
55-
for (String script : sql.scripts()) {
83+
String[] scripts = getScripts(sql, position);
84+
for (String script : scripts) {
5685
connection.createStatement().execute(IoUtils.content(CLASS_LOADER, script));
5786
}
5887
} catch (SQLException | IOException e) {
59-
throw new IllegalStateException("Fail to execute sql.", e);
88+
throw new IllegalStateException("Failed to execute sql.", e);
89+
}
90+
}
91+
92+
private static String[] getScripts(Sql sql, Sql.Position position) {
93+
if (position == Sql.Position.BEFORE) {
94+
return sql.before();
95+
} else {
96+
return sql.after();
6097
}
6198
}
6299
}

framework/fit/java/fit-test/fit-test-framework/src/main/java/modelengine/fitframework/test/domain/resolver/DefaultTestClassResolver.java

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,9 @@
2121
import java.util.HashSet;
2222
import java.util.Optional;
2323
import java.util.Set;
24+
import java.util.function.Function;
2425
import java.util.stream.Collectors;
26+
import java.util.stream.Stream;
2527

2628
/**
2729
* 默认的单测类解析器。
@@ -30,8 +32,7 @@
3032
* @since 2023-01-17
3133
*/
3234
public class DefaultTestClassResolver implements TestClassResolver {
33-
private static final Set<String> DEFAULT_SCAN_PACKAGES = new HashSet<>(Arrays.asList(
34-
"modelengine.fit.value",
35+
private static final Set<String> DEFAULT_SCAN_PACKAGES = new HashSet<>(Arrays.asList("modelengine.fit.value",
3536
"modelengine.fit.serialization",
3637
"modelengine.fitframework.validation"));
3738

@@ -41,7 +42,8 @@ public TestContextConfiguration resolve(Class<?> clazz) {
4142
Class<?>[] includeClasses = this.resolveIncludeClasses(testConfigurationClass);
4243
return TestContextConfiguration.custom()
4344
.testClass(clazz)
44-
.includeClasses(includeClasses)
45+
.includeClasses(Stream.of(includeClasses)
46+
.collect(Collectors.toMap(Function.identity(), key -> () -> null)))
4547
.excludeClasses(this.resolveExcludeClasses(clazz))
4648
.scannedPackages(this.scanBeans(includeClasses))
4749
.mockedBeanFields(this.scanMockBeansFieldSet(clazz))
@@ -86,7 +88,7 @@ private Set<String> scanBeans(Class<?>[] classes) {
8688

8789
private Set<String> getBasePackages(Class<?> clazz) {
8890
Optional<ScanPackages> opScanPackagesAnnotation = AnnotationUtils.getAnnotation(clazz, ScanPackages.class);
89-
if (!opScanPackagesAnnotation.isPresent()) {
91+
if (opScanPackagesAnnotation.isEmpty()) {
9092
return new HashSet<>();
9193
}
9294
Set<String> basePackages = new HashSet<>(Arrays.asList(opScanPackagesAnnotation.get().value()));

0 commit comments

Comments
 (0)