diff --git a/src/main/java/com/xkcoding/justauth/AuthRequestFactory.java b/src/main/java/com/xkcoding/justauth/AuthRequestFactory.java index 7f3a068..2e5e824 100644 --- a/src/main/java/com/xkcoding/justauth/AuthRequestFactory.java +++ b/src/main/java/com/xkcoding/justauth/AuthRequestFactory.java @@ -32,7 +32,8 @@ import me.zhyd.oauth.config.AuthSource; import me.zhyd.oauth.enums.AuthResponseStatus; import me.zhyd.oauth.exception.AuthException; -import me.zhyd.oauth.request.*; +import me.zhyd.oauth.request.AuthDefaultRequest; +import me.zhyd.oauth.request.AuthRequest; import org.springframework.util.CollectionUtils; import java.net.InetSocketAddress; @@ -178,73 +179,102 @@ private AuthRequest getDefaultRequest(String source) { configureHttpConfig(authDefaultSource.name(), config, properties.getHttpConfig()); switch (authDefaultSource) { - case GITHUB: - return new AuthGithubRequest(config, authStateCache); - case WEIBO: - return new AuthWeiboRequest(config, authStateCache); - case GITEE: - return new AuthGiteeRequest(config, authStateCache); - case DINGTALK: - return new AuthDingTalkRequest(config, authStateCache); - case BAIDU: - return new AuthBaiduRequest(config, authStateCache); + case FEISHU: case CSDN: - return new AuthCsdnRequest(config, authStateCache); - case CODING: - return new AuthCodingRequest(config, authStateCache); - case OSCHINA: - return new AuthOschinaRequest(config, authStateCache); - case ALIPAY: - return new AuthAlipayRequest(config, authStateCache); - case QQ: - return new AuthQqRequest(config, authStateCache); - case WECHAT_MP: - return new AuthWeChatMpRequest(config, authStateCache); - case WECHAT_OPEN: - return new AuthWeChatOpenRequest(config, authStateCache); - case WECHAT_ENTERPRISE: - return new AuthWeChatEnterpriseRequest(config, authStateCache); - case TAOBAO: - return new AuthTaobaoRequest(config, authStateCache); - case GOOGLE: - return new AuthGoogleRequest(config, authStateCache); - case FACEBOOK: - return new AuthFacebookRequest(config, authStateCache); - case DOUYIN: - return new AuthDouyinRequest(config, authStateCache); - case LINKEDIN: - return new AuthLinkedinRequest(config, authStateCache); - case MICROSOFT: - return new AuthMicrosoftRequest(config, authStateCache); - case MI: - return new AuthMiRequest(config, authStateCache); - case TOUTIAO: - return new AuthToutiaoRequest(config, authStateCache); - case TEAMBITION: - return new AuthTeambitionRequest(config, authStateCache); - case RENREN: - return new AuthRenrenRequest(config, authStateCache); - case PINTEREST: - return new AuthPinterestRequest(config, authStateCache); - case STACK_OVERFLOW: - return new AuthStackOverflowRequest(config, authStateCache); - case HUAWEI: - return new AuthHuaweiRequest(config, authStateCache); - case GITLAB: - return new AuthGitlabRequest(config, authStateCache); - case KUJIALE: - return new AuthKujialeRequest(config, authStateCache); - case ELEME: - return new AuthElemeRequest(config, authStateCache); - case MEITUAN: - return new AuthMeituanRequest(config, authStateCache); - case TWITTER: - return new AuthTwitterRequest(config, authStateCache); - default: return null; + default: + } + + return getAuthDefaultRequest(config, authDefaultSource, authStateCache); + } + + /** + * 获取 {@link AuthDefaultRequest} 的适配器 + * + * @param config {@link AuthDefaultRequest} 的 {@link AuthConfig} + * @param source {@link AuthDefaultRequest} 的 {@link AuthSource} + * @param authStateCache {@link AuthDefaultRequest} 的 {@link AuthStateCache} + * @return {@link AuthDefaultRequest} 相对应的适配器 + */ + private AuthDefaultRequest getAuthDefaultRequest(AuthConfig config, + AuthDefaultSource source, + AuthStateCache authStateCache) { + + Object[] arguments = new Object[]{config, authStateCache}; + + final Class authDefaultRequestClass = getAuthRequestClassBySource(source); + + if (!AuthDefaultRequest.class.isAssignableFrom(authDefaultRequestClass)) { + throw new RuntimeException(authDefaultRequestClass.getName() + " Must be a subclass of me.zhyd.oauth.request.AuthDefaultRequest"); + } + + return (AuthDefaultRequest) ReflectUtil.newInstance(authDefaultRequestClass, arguments); + + } + + /** + * {@link AuthDefaultRequest} 子类的包名 + */ + public static final String AUTH_REQUEST_PACKAGE = "me.zhyd.oauth.request."; + /** + * {@link AuthDefaultRequest} 子类类名前缀 + */ + public static final String AUTH_REQUEST_PREFIX = "Auth"; + /** + * {@link AuthDefaultRequest} 子类类名后缀 + */ + public static final String AUTH_REQUEST_SUFFIX = "Request"; + /** + * {@link AuthDefaultSource} 枚举名称分隔符 + */ + public static final String SEPARATOR = "_"; + + /** + * 根据 {@link AuthDefaultSource} 获取对应的 {@link AuthDefaultRequest} 子类的 Class + * @param source {@link AuthDefaultSource} + * @return 返回 {@link AuthDefaultSource} 对应的 {@link AuthDefaultRequest} 子类的 Class + */ + public static Class getAuthRequestClassBySource(AuthDefaultSource source) { + String[] splits = source.name().split(SEPARATOR); + String authRequestClassName = AUTH_REQUEST_PACKAGE + toAuthRequestClassName(splits); + try { + return Class.forName(authRequestClassName); + } + catch (Exception e) { + throw new RuntimeException(e.getMessage(), e); } } + /** + * 根据传入的字符串数组转换为类名格式的字符串, 另外 DingTalk -> DingTalk, WECHAT -> WeChat. + * @param splits 字符串数组, 例如: [WECHAT, OPEN] + * @return 返回类名格式的字符串, 如传入的数组是: [STACK, OVERFLOW] 那么返回 AuthStackOverflowRequest + */ + private static String toAuthRequestClassName(String[] splits) { + StringBuilder sb = new StringBuilder(); + sb.append(AUTH_REQUEST_PREFIX); + for (String split : splits) { + split = split.toLowerCase(); + if (AuthDefaultSource.DINGTALK.name().equalsIgnoreCase(split)) { + sb.append("DingTalk"); + continue; + } + if ("wechat".equalsIgnoreCase(split)) { + sb.append("WeChat"); + continue; + } + if (split.length() > 1) { + sb.append(split.substring(0, 1).toUpperCase()).append(split.substring(1)); + } + else { + sb.append(split.toUpperCase()); + } + } + sb.append(AUTH_REQUEST_SUFFIX); + return sb.toString(); + } + + /** * 配置 http 相关的配置 *